diff --git a/.venv/lib/python3.13/site-packages/sympy/algebras/__init__.py b/.venv/lib/python3.13/site-packages/sympy/algebras/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..58013d2e0377a016d1a21fbf21c344ee76765189 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/algebras/__init__.py @@ -0,0 +1,3 @@ +from .quaternion import Quaternion + +__all__ = ["Quaternion",] diff --git a/.venv/lib/python3.13/site-packages/sympy/algebras/quaternion.py b/.venv/lib/python3.13/site-packages/sympy/algebras/quaternion.py new file mode 100644 index 0000000000000000000000000000000000000000..1bf4b363545128e50294e825461106ef9b41990d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/algebras/quaternion.py @@ -0,0 +1,1666 @@ +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.relational import is_eq +from sympy.functions.elementary.complexes import (conjugate, im, re, sign) +from sympy.functions.elementary.exponential import (exp, log as ln) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, asin, atan2) +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.simplify.trigsimp import trigsimp +from sympy.integrals.integrals import integrate +from sympy.matrices.dense import MutableDenseMatrix as Matrix +from sympy.core.sympify import sympify, _sympify +from sympy.core.expr import Expr +from sympy.core.logic import fuzzy_not, fuzzy_or +from sympy.utilities.misc import as_int + +from mpmath.libmp.libmpf import prec_to_dps + + +def _check_norm(elements, norm): + """validate if input norm is consistent""" + if norm is not None and norm.is_number: + if norm.is_positive is False: + raise ValueError("Input norm must be positive.") + + numerical = all(i.is_number and i.is_real is True for i in elements) + if numerical and is_eq(norm**2, sum(i**2 for i in elements)) is False: + raise ValueError("Incompatible value for norm.") + + +def _is_extrinsic(seq): + """validate seq and return True if seq is lowercase and False if uppercase""" + if type(seq) != str: + raise ValueError('Expected seq to be a string.') + if len(seq) != 3: + raise ValueError("Expected 3 axes, got `{}`.".format(seq)) + + intrinsic = seq.isupper() + extrinsic = seq.islower() + if not (intrinsic or extrinsic): + raise ValueError("seq must either be fully uppercase (for extrinsic " + "rotations), or fully lowercase, for intrinsic " + "rotations).") + + i, j, k = seq.lower() + if (i == j) or (j == k): + raise ValueError("Consecutive axes must be different") + + bad = set(seq) - set('xyzXYZ') + if bad: + raise ValueError("Expected axes from `seq` to be from " + "['x', 'y', 'z'] or ['X', 'Y', 'Z'], " + "got {}".format(''.join(bad))) + + return extrinsic + + +class Quaternion(Expr): + """Provides basic quaternion operations. + Quaternion objects can be instantiated as ``Quaternion(a, b, c, d)`` + as in $q = a + bi + cj + dk$. + + Parameters + ========== + + norm : None or number + Pre-defined quaternion norm. If a value is given, Quaternion.norm + returns this pre-defined value instead of calculating the norm + + Examples + ======== + + >>> from sympy import Quaternion + >>> q = Quaternion(1, 2, 3, 4) + >>> q + 1 + 2*i + 3*j + 4*k + + Quaternions over complex fields can be defined as: + + >>> from sympy import Quaternion + >>> from sympy import symbols, I + >>> x = symbols('x') + >>> q1 = Quaternion(x, x**3, x, x**2, real_field = False) + >>> q2 = Quaternion(3 + 4*I, 2 + 5*I, 0, 7 + 8*I, real_field = False) + >>> q1 + x + x**3*i + x*j + x**2*k + >>> q2 + (3 + 4*I) + (2 + 5*I)*i + 0*j + (7 + 8*I)*k + + Defining symbolic unit quaternions: + + >>> from sympy import Quaternion + >>> from sympy.abc import w, x, y, z + >>> q = Quaternion(w, x, y, z, norm=1) + >>> q + w + x*i + y*j + z*k + >>> q.norm() + 1 + + References + ========== + + .. [1] https://www.euclideanspace.com/maths/algebra/realNormedAlgebra/quaternions/ + .. [2] https://en.wikipedia.org/wiki/Quaternion + + """ + _op_priority = 11.0 + + is_commutative = False + + def __new__(cls, a=0, b=0, c=0, d=0, real_field=True, norm=None): + a, b, c, d = map(sympify, (a, b, c, d)) + + if any(i.is_commutative is False for i in [a, b, c, d]): + raise ValueError("arguments have to be commutative") + obj = super().__new__(cls, a, b, c, d) + obj._real_field = real_field + obj.set_norm(norm) + return obj + + def set_norm(self, norm): + """Sets norm of an already instantiated quaternion. + + Parameters + ========== + + norm : None or number + Pre-defined quaternion norm. If a value is given, Quaternion.norm + returns this pre-defined value instead of calculating the norm + + Examples + ======== + + >>> from sympy import Quaternion + >>> from sympy.abc import a, b, c, d + >>> q = Quaternion(a, b, c, d) + >>> q.norm() + sqrt(a**2 + b**2 + c**2 + d**2) + + Setting the norm: + + >>> q.set_norm(1) + >>> q.norm() + 1 + + Removing set norm: + + >>> q.set_norm(None) + >>> q.norm() + sqrt(a**2 + b**2 + c**2 + d**2) + + """ + norm = sympify(norm) + _check_norm(self.args, norm) + self._norm = norm + + @property + def a(self): + return self.args[0] + + @property + def b(self): + return self.args[1] + + @property + def c(self): + return self.args[2] + + @property + def d(self): + return self.args[3] + + @property + def real_field(self): + return self._real_field + + @property + def product_matrix_left(self): + r"""Returns 4 x 4 Matrix equivalent to a Hamilton product from the + left. This can be useful when treating quaternion elements as column + vectors. Given a quaternion $q = a + bi + cj + dk$ where a, b, c and d + are real numbers, the product matrix from the left is: + + .. math:: + + M = \begin{bmatrix} a &-b &-c &-d \\ + b & a &-d & c \\ + c & d & a &-b \\ + d &-c & b & a \end{bmatrix} + + Examples + ======== + + >>> from sympy import Quaternion + >>> from sympy.abc import a, b, c, d + >>> q1 = Quaternion(1, 0, 0, 1) + >>> q2 = Quaternion(a, b, c, d) + >>> q1.product_matrix_left + Matrix([ + [1, 0, 0, -1], + [0, 1, -1, 0], + [0, 1, 1, 0], + [1, 0, 0, 1]]) + + >>> q1.product_matrix_left * q2.to_Matrix() + Matrix([ + [a - d], + [b - c], + [b + c], + [a + d]]) + + This is equivalent to: + + >>> (q1 * q2).to_Matrix() + Matrix([ + [a - d], + [b - c], + [b + c], + [a + d]]) + """ + return Matrix([ + [self.a, -self.b, -self.c, -self.d], + [self.b, self.a, -self.d, self.c], + [self.c, self.d, self.a, -self.b], + [self.d, -self.c, self.b, self.a]]) + + @property + def product_matrix_right(self): + r"""Returns 4 x 4 Matrix equivalent to a Hamilton product from the + right. This can be useful when treating quaternion elements as column + vectors. Given a quaternion $q = a + bi + cj + dk$ where a, b, c and d + are real numbers, the product matrix from the left is: + + .. math:: + + M = \begin{bmatrix} a &-b &-c &-d \\ + b & a & d &-c \\ + c &-d & a & b \\ + d & c &-b & a \end{bmatrix} + + + Examples + ======== + + >>> from sympy import Quaternion + >>> from sympy.abc import a, b, c, d + >>> q1 = Quaternion(a, b, c, d) + >>> q2 = Quaternion(1, 0, 0, 1) + >>> q2.product_matrix_right + Matrix([ + [1, 0, 0, -1], + [0, 1, 1, 0], + [0, -1, 1, 0], + [1, 0, 0, 1]]) + + Note the switched arguments: the matrix represents the quaternion on + the right, but is still considered as a matrix multiplication from the + left. + + >>> q2.product_matrix_right * q1.to_Matrix() + Matrix([ + [ a - d], + [ b + c], + [-b + c], + [ a + d]]) + + This is equivalent to: + + >>> (q1 * q2).to_Matrix() + Matrix([ + [ a - d], + [ b + c], + [-b + c], + [ a + d]]) + """ + return Matrix([ + [self.a, -self.b, -self.c, -self.d], + [self.b, self.a, self.d, -self.c], + [self.c, -self.d, self.a, self.b], + [self.d, self.c, -self.b, self.a]]) + + def to_Matrix(self, vector_only=False): + """Returns elements of quaternion as a column vector. + By default, a ``Matrix`` of length 4 is returned, with the real part as the + first element. + If ``vector_only`` is ``True``, returns only imaginary part as a Matrix of + length 3. + + Parameters + ========== + + vector_only : bool + If True, only imaginary part is returned. + Default value: False + + Returns + ======= + + Matrix + A column vector constructed by the elements of the quaternion. + + Examples + ======== + + >>> from sympy import Quaternion + >>> from sympy.abc import a, b, c, d + >>> q = Quaternion(a, b, c, d) + >>> q + a + b*i + c*j + d*k + + >>> q.to_Matrix() + Matrix([ + [a], + [b], + [c], + [d]]) + + + >>> q.to_Matrix(vector_only=True) + Matrix([ + [b], + [c], + [d]]) + + """ + if vector_only: + return Matrix(self.args[1:]) + else: + return Matrix(self.args) + + @classmethod + def from_Matrix(cls, elements): + """Returns quaternion from elements of a column vector`. + If vector_only is True, returns only imaginary part as a Matrix of + length 3. + + Parameters + ========== + + elements : Matrix, list or tuple of length 3 or 4. If length is 3, + assume real part is zero. + Default value: False + + Returns + ======= + + Quaternion + A quaternion created from the input elements. + + Examples + ======== + + >>> from sympy import Quaternion + >>> from sympy.abc import a, b, c, d + >>> q = Quaternion.from_Matrix([a, b, c, d]) + >>> q + a + b*i + c*j + d*k + + >>> q = Quaternion.from_Matrix([b, c, d]) + >>> q + 0 + b*i + c*j + d*k + + """ + length = len(elements) + if length != 3 and length != 4: + raise ValueError("Input elements must have length 3 or 4, got {} " + "elements".format(length)) + + if length == 3: + return Quaternion(0, *elements) + else: + return Quaternion(*elements) + + @classmethod + def from_euler(cls, angles, seq): + """Returns quaternion equivalent to rotation represented by the Euler + angles, in the sequence defined by ``seq``. + + Parameters + ========== + + angles : list, tuple or Matrix of 3 numbers + The Euler angles (in radians). + seq : string of length 3 + Represents the sequence of rotations. + For extrinsic rotations, seq must be all lowercase and its elements + must be from the set ``{'x', 'y', 'z'}`` + For intrinsic rotations, seq must be all uppercase and its elements + must be from the set ``{'X', 'Y', 'Z'}`` + + Returns + ======= + + Quaternion + The normalized rotation quaternion calculated from the Euler angles + in the given sequence. + + Examples + ======== + + >>> from sympy import Quaternion + >>> from sympy import pi + >>> q = Quaternion.from_euler([pi/2, 0, 0], 'xyz') + >>> q + sqrt(2)/2 + sqrt(2)/2*i + 0*j + 0*k + + >>> q = Quaternion.from_euler([0, pi/2, pi] , 'zyz') + >>> q + 0 + (-sqrt(2)/2)*i + 0*j + sqrt(2)/2*k + + >>> q = Quaternion.from_euler([0, pi/2, pi] , 'ZYZ') + >>> q + 0 + sqrt(2)/2*i + 0*j + sqrt(2)/2*k + + """ + + if len(angles) != 3: + raise ValueError("3 angles must be given.") + + extrinsic = _is_extrinsic(seq) + i, j, k = seq.lower() + + # get elementary basis vectors + ei = [1 if n == i else 0 for n in 'xyz'] + ej = [1 if n == j else 0 for n in 'xyz'] + ek = [1 if n == k else 0 for n in 'xyz'] + + # calculate distinct quaternions + qi = cls.from_axis_angle(ei, angles[0]) + qj = cls.from_axis_angle(ej, angles[1]) + qk = cls.from_axis_angle(ek, angles[2]) + + if extrinsic: + return trigsimp(qk * qj * qi) + else: + return trigsimp(qi * qj * qk) + + def to_euler(self, seq, angle_addition=True, avoid_square_root=False): + r"""Returns Euler angles representing same rotation as the quaternion, + in the sequence given by ``seq``. This implements the method described + in [1]_. + + For degenerate cases (gymbal lock cases), the third angle is + set to zero. + + Parameters + ========== + + seq : string of length 3 + Represents the sequence of rotations. + For extrinsic rotations, seq must be all lowercase and its elements + must be from the set ``{'x', 'y', 'z'}`` + For intrinsic rotations, seq must be all uppercase and its elements + must be from the set ``{'X', 'Y', 'Z'}`` + + angle_addition : bool + When True, first and third angles are given as an addition and + subtraction of two simpler ``atan2`` expressions. When False, the + first and third angles are each given by a single more complicated + ``atan2`` expression. This equivalent expression is given by: + + .. math:: + + \operatorname{atan_2} (b,a) \pm \operatorname{atan_2} (d,c) = + \operatorname{atan_2} (bc\pm ad, ac\mp bd) + + Default value: True + + avoid_square_root : bool + When True, the second angle is calculated with an expression based + on ``acos``, which is slightly more complicated but avoids a square + root. When False, second angle is calculated with ``atan2``, which + is simpler and can be better for numerical reasons (some + numerical implementations of ``acos`` have problems near zero). + Default value: False + + + Returns + ======= + + Tuple + The Euler angles calculated from the quaternion + + Examples + ======== + + >>> from sympy import Quaternion + >>> from sympy.abc import a, b, c, d + >>> euler = Quaternion(a, b, c, d).to_euler('zyz') + >>> euler + (-atan2(-b, c) + atan2(d, a), + 2*atan2(sqrt(b**2 + c**2), sqrt(a**2 + d**2)), + atan2(-b, c) + atan2(d, a)) + + + References + ========== + + .. [1] https://doi.org/10.1371/journal.pone.0276302 + + """ + if self.is_zero_quaternion(): + raise ValueError('Cannot convert a quaternion with norm 0.') + + angles = [0, 0, 0] + + extrinsic = _is_extrinsic(seq) + i, j, k = seq.lower() + + # get index corresponding to elementary basis vectors + i = 'xyz'.index(i) + 1 + j = 'xyz'.index(j) + 1 + k = 'xyz'.index(k) + 1 + + if not extrinsic: + i, k = k, i + + # check if sequence is symmetric + symmetric = i == k + if symmetric: + k = 6 - i - j + + # parity of the permutation + sign = (i - j) * (j - k) * (k - i) // 2 + + # permutate elements + elements = [self.a, self.b, self.c, self.d] + a = elements[0] + b = elements[i] + c = elements[j] + d = elements[k] * sign + + if not symmetric: + a, b, c, d = a - c, b + d, c + a, d - b + + if avoid_square_root: + if symmetric: + n2 = self.norm()**2 + angles[1] = acos((a * a + b * b - c * c - d * d) / n2) + else: + n2 = 2 * self.norm()**2 + angles[1] = asin((c * c + d * d - a * a - b * b) / n2) + else: + angles[1] = 2 * atan2(sqrt(c * c + d * d), sqrt(a * a + b * b)) + if not symmetric: + angles[1] -= S.Pi / 2 + + # Check for singularities in numerical cases + case = 0 + if is_eq(c, S.Zero) and is_eq(d, S.Zero): + case = 1 + if is_eq(a, S.Zero) and is_eq(b, S.Zero): + case = 2 + + if case == 0: + if angle_addition: + angles[0] = atan2(b, a) + atan2(d, c) + angles[2] = atan2(b, a) - atan2(d, c) + else: + angles[0] = atan2(b*c + a*d, a*c - b*d) + angles[2] = atan2(b*c - a*d, a*c + b*d) + + else: # any degenerate case + angles[2 * (not extrinsic)] = S.Zero + if case == 1: + angles[2 * extrinsic] = 2 * atan2(b, a) + else: + angles[2 * extrinsic] = 2 * atan2(d, c) + angles[2 * extrinsic] *= (-1 if extrinsic else 1) + + # for Tait-Bryan angles + if not symmetric: + angles[0] *= sign + + if extrinsic: + return tuple(angles[::-1]) + else: + return tuple(angles) + + @classmethod + def from_axis_angle(cls, vector, angle): + """Returns a rotation quaternion given the axis and the angle of rotation. + + Parameters + ========== + + vector : tuple of three numbers + The vector representation of the given axis. + angle : number + The angle by which axis is rotated (in radians). + + Returns + ======= + + Quaternion + The normalized rotation quaternion calculated from the given axis and the angle of rotation. + + Examples + ======== + + >>> from sympy import Quaternion + >>> from sympy import pi, sqrt + >>> q = Quaternion.from_axis_angle((sqrt(3)/3, sqrt(3)/3, sqrt(3)/3), 2*pi/3) + >>> q + 1/2 + 1/2*i + 1/2*j + 1/2*k + + """ + (x, y, z) = vector + norm = sqrt(x**2 + y**2 + z**2) + (x, y, z) = (x / norm, y / norm, z / norm) + s = sin(angle * S.Half) + a = cos(angle * S.Half) + b = x * s + c = y * s + d = z * s + + # note that this quaternion is already normalized by construction: + # c^2 + (s*x)^2 + (s*y)^2 + (s*z)^2 = c^2 + s^2*(x^2 + y^2 + z^2) = c^2 + s^2 * 1 = c^2 + s^2 = 1 + # so, what we return is a normalized quaternion + + return cls(a, b, c, d) + + @classmethod + def from_rotation_matrix(cls, M): + """Returns the equivalent quaternion of a matrix. The quaternion will be normalized + only if the matrix is special orthogonal (orthogonal and det(M) = 1). + + Parameters + ========== + + M : Matrix + Input matrix to be converted to equivalent quaternion. M must be special + orthogonal (orthogonal and det(M) = 1) for the quaternion to be normalized. + + Returns + ======= + + Quaternion + The quaternion equivalent to given matrix. + + Examples + ======== + + >>> from sympy import Quaternion + >>> from sympy import Matrix, symbols, cos, sin, trigsimp + >>> x = symbols('x') + >>> M = Matrix([[cos(x), -sin(x), 0], [sin(x), cos(x), 0], [0, 0, 1]]) + >>> q = trigsimp(Quaternion.from_rotation_matrix(M)) + >>> q + sqrt(2)*sqrt(cos(x) + 1)/2 + 0*i + 0*j + sqrt(2 - 2*cos(x))*sign(sin(x))/2*k + + """ + + absQ = M.det()**Rational(1, 3) + + a = sqrt(absQ + M[0, 0] + M[1, 1] + M[2, 2]) / 2 + b = sqrt(absQ + M[0, 0] - M[1, 1] - M[2, 2]) / 2 + c = sqrt(absQ - M[0, 0] + M[1, 1] - M[2, 2]) / 2 + d = sqrt(absQ - M[0, 0] - M[1, 1] + M[2, 2]) / 2 + + b = b * sign(M[2, 1] - M[1, 2]) + c = c * sign(M[0, 2] - M[2, 0]) + d = d * sign(M[1, 0] - M[0, 1]) + + return Quaternion(a, b, c, d) + + def __add__(self, other): + return self.add(other) + + def __radd__(self, other): + return self.add(other) + + def __sub__(self, other): + return self.add(other*-1) + + def __mul__(self, other): + return self._generic_mul(self, _sympify(other)) + + def __rmul__(self, other): + return self._generic_mul(_sympify(other), self) + + def __pow__(self, p): + return self.pow(p) + + def __neg__(self): + return Quaternion(-self.a, -self.b, -self.c, -self.d) + + def __truediv__(self, other): + return self * sympify(other)**-1 + + def __rtruediv__(self, other): + return sympify(other) * self**-1 + + def _eval_Integral(self, *args): + return self.integrate(*args) + + def diff(self, *symbols, **kwargs): + kwargs.setdefault('evaluate', True) + return self.func(*[a.diff(*symbols, **kwargs) for a in self.args]) + + def add(self, other): + """Adds quaternions. + + Parameters + ========== + + other : Quaternion + The quaternion to add to current (self) quaternion. + + Returns + ======= + + Quaternion + The resultant quaternion after adding self to other + + Examples + ======== + + >>> from sympy import Quaternion + >>> from sympy import symbols + >>> q1 = Quaternion(1, 2, 3, 4) + >>> q2 = Quaternion(5, 6, 7, 8) + >>> q1.add(q2) + 6 + 8*i + 10*j + 12*k + >>> q1 + 5 + 6 + 2*i + 3*j + 4*k + >>> x = symbols('x', real = True) + >>> q1.add(x) + (x + 1) + 2*i + 3*j + 4*k + + Quaternions over complex fields : + + >>> from sympy import Quaternion + >>> from sympy import I + >>> q3 = Quaternion(3 + 4*I, 2 + 5*I, 0, 7 + 8*I, real_field = False) + >>> q3.add(2 + 3*I) + (5 + 7*I) + (2 + 5*I)*i + 0*j + (7 + 8*I)*k + + """ + q1 = self + q2 = sympify(other) + + # If q2 is a number or a SymPy expression instead of a quaternion + if not isinstance(q2, Quaternion): + if q1.real_field and q2.is_complex: + return Quaternion(re(q2) + q1.a, im(q2) + q1.b, q1.c, q1.d) + elif q2.is_commutative: + return Quaternion(q1.a + q2, q1.b, q1.c, q1.d) + else: + raise ValueError("Only commutative expressions can be added with a Quaternion.") + + return Quaternion(q1.a + q2.a, q1.b + q2.b, q1.c + q2.c, q1.d + + q2.d) + + def mul(self, other): + """Multiplies quaternions. + + Parameters + ========== + + other : Quaternion or symbol + The quaternion to multiply to current (self) quaternion. + + Returns + ======= + + Quaternion + The resultant quaternion after multiplying self with other + + Examples + ======== + + >>> from sympy import Quaternion + >>> from sympy import symbols + >>> q1 = Quaternion(1, 2, 3, 4) + >>> q2 = Quaternion(5, 6, 7, 8) + >>> q1.mul(q2) + (-60) + 12*i + 30*j + 24*k + >>> q1.mul(2) + 2 + 4*i + 6*j + 8*k + >>> x = symbols('x', real = True) + >>> q1.mul(x) + x + 2*x*i + 3*x*j + 4*x*k + + Quaternions over complex fields : + + >>> from sympy import Quaternion + >>> from sympy import I + >>> q3 = Quaternion(3 + 4*I, 2 + 5*I, 0, 7 + 8*I, real_field = False) + >>> q3.mul(2 + 3*I) + (2 + 3*I)*(3 + 4*I) + (2 + 3*I)*(2 + 5*I)*i + 0*j + (2 + 3*I)*(7 + 8*I)*k + + """ + return self._generic_mul(self, _sympify(other)) + + @staticmethod + def _generic_mul(q1, q2): + """Generic multiplication. + + Parameters + ========== + + q1 : Quaternion or symbol + q2 : Quaternion or symbol + + It is important to note that if neither q1 nor q2 is a Quaternion, + this function simply returns q1 * q2. + + Returns + ======= + + Quaternion + The resultant quaternion after multiplying q1 and q2 + + Examples + ======== + + >>> from sympy import Quaternion + >>> from sympy import Symbol, S + >>> q1 = Quaternion(1, 2, 3, 4) + >>> q2 = Quaternion(5, 6, 7, 8) + >>> Quaternion._generic_mul(q1, q2) + (-60) + 12*i + 30*j + 24*k + >>> Quaternion._generic_mul(q1, S(2)) + 2 + 4*i + 6*j + 8*k + >>> x = Symbol('x', real = True) + >>> Quaternion._generic_mul(q1, x) + x + 2*x*i + 3*x*j + 4*x*k + + Quaternions over complex fields : + + >>> from sympy import I + >>> q3 = Quaternion(3 + 4*I, 2 + 5*I, 0, 7 + 8*I, real_field = False) + >>> Quaternion._generic_mul(q3, 2 + 3*I) + (2 + 3*I)*(3 + 4*I) + (2 + 3*I)*(2 + 5*I)*i + 0*j + (2 + 3*I)*(7 + 8*I)*k + + """ + # None is a Quaternion: + if not isinstance(q1, Quaternion) and not isinstance(q2, Quaternion): + return q1 * q2 + + # If q1 is a number or a SymPy expression instead of a quaternion + if not isinstance(q1, Quaternion): + if q2.real_field and q1.is_complex: + return Quaternion(re(q1), im(q1), 0, 0) * q2 + elif q1.is_commutative: + return Quaternion(q1 * q2.a, q1 * q2.b, q1 * q2.c, q1 * q2.d) + else: + raise ValueError("Only commutative expressions can be multiplied with a Quaternion.") + + # If q2 is a number or a SymPy expression instead of a quaternion + if not isinstance(q2, Quaternion): + if q1.real_field and q2.is_complex: + return q1 * Quaternion(re(q2), im(q2), 0, 0) + elif q2.is_commutative: + return Quaternion(q2 * q1.a, q2 * q1.b, q2 * q1.c, q2 * q1.d) + else: + raise ValueError("Only commutative expressions can be multiplied with a Quaternion.") + + # If any of the quaternions has a fixed norm, pre-compute norm + if q1._norm is None and q2._norm is None: + norm = None + else: + norm = q1.norm() * q2.norm() + + return Quaternion(-q1.b*q2.b - q1.c*q2.c - q1.d*q2.d + q1.a*q2.a, + q1.b*q2.a + q1.c*q2.d - q1.d*q2.c + q1.a*q2.b, + -q1.b*q2.d + q1.c*q2.a + q1.d*q2.b + q1.a*q2.c, + q1.b*q2.c - q1.c*q2.b + q1.d*q2.a + q1.a * q2.d, + norm=norm) + + def _eval_conjugate(self): + """Returns the conjugate of the quaternion.""" + q = self + return Quaternion(q.a, -q.b, -q.c, -q.d, norm=q._norm) + + def norm(self): + """Returns the norm of the quaternion.""" + if self._norm is None: # check if norm is pre-defined + q = self + # trigsimp is used to simplify sin(x)^2 + cos(x)^2 (these terms + # arise when from_axis_angle is used). + return sqrt(trigsimp(q.a**2 + q.b**2 + q.c**2 + q.d**2)) + + return self._norm + + def normalize(self): + """Returns the normalized form of the quaternion.""" + q = self + return q * (1/q.norm()) + + def inverse(self): + """Returns the inverse of the quaternion.""" + q = self + if not q.norm(): + raise ValueError("Cannot compute inverse for a quaternion with zero norm") + return conjugate(q) * (1/q.norm()**2) + + def pow(self, p): + """Finds the pth power of the quaternion. + + Parameters + ========== + + p : int + Power to be applied on quaternion. + + Returns + ======= + + Quaternion + Returns the p-th power of the current quaternion. + Returns the inverse if p = -1. + + Examples + ======== + + >>> from sympy import Quaternion + >>> q = Quaternion(1, 2, 3, 4) + >>> q.pow(4) + 668 + (-224)*i + (-336)*j + (-448)*k + + """ + try: + q, p = self, as_int(p) + except ValueError: + return NotImplemented + + if p < 0: + q, p = q.inverse(), -p + + if p == 1: + return q + + res = Quaternion(1, 0, 0, 0) + while p > 0: + if p & 1: + res *= q + q *= q + p >>= 1 + + return res + + def exp(self): + """Returns the exponential of $q$, given by $e^q$. + + Returns + ======= + + Quaternion + The exponential of the quaternion. + + Examples + ======== + + >>> from sympy import Quaternion + >>> q = Quaternion(1, 2, 3, 4) + >>> q.exp() + E*cos(sqrt(29)) + + 2*sqrt(29)*E*sin(sqrt(29))/29*i + + 3*sqrt(29)*E*sin(sqrt(29))/29*j + + 4*sqrt(29)*E*sin(sqrt(29))/29*k + + """ + # exp(q) = e^a(cos||v|| + v/||v||*sin||v||) + q = self + vector_norm = sqrt(q.b**2 + q.c**2 + q.d**2) + a = exp(q.a) * cos(vector_norm) + b = exp(q.a) * sin(vector_norm) * q.b / vector_norm + c = exp(q.a) * sin(vector_norm) * q.c / vector_norm + d = exp(q.a) * sin(vector_norm) * q.d / vector_norm + + return Quaternion(a, b, c, d) + + def log(self): + r"""Returns the logarithm of the quaternion, given by $\log q$. + + Examples + ======== + + >>> from sympy import Quaternion + >>> q = Quaternion(1, 2, 3, 4) + >>> q.log() + log(sqrt(30)) + + 2*sqrt(29)*acos(sqrt(30)/30)/29*i + + 3*sqrt(29)*acos(sqrt(30)/30)/29*j + + 4*sqrt(29)*acos(sqrt(30)/30)/29*k + + """ + # log(q) = log||q|| + v/||v||*arccos(a/||q||) + q = self + vector_norm = sqrt(q.b**2 + q.c**2 + q.d**2) + q_norm = q.norm() + a = ln(q_norm) + b = q.b * acos(q.a / q_norm) / vector_norm + c = q.c * acos(q.a / q_norm) / vector_norm + d = q.d * acos(q.a / q_norm) / vector_norm + + return Quaternion(a, b, c, d) + + def _eval_subs(self, *args): + elements = [i.subs(*args) for i in self.args] + norm = self._norm + if norm is not None: + norm = norm.subs(*args) + _check_norm(elements, norm) + return Quaternion(*elements, norm=norm) + + def _eval_evalf(self, prec): + """Returns the floating point approximations (decimal numbers) of the quaternion. + + Returns + ======= + + Quaternion + Floating point approximations of quaternion(self) + + Examples + ======== + + >>> from sympy import Quaternion + >>> from sympy import sqrt + >>> q = Quaternion(1/sqrt(1), 1/sqrt(2), 1/sqrt(3), 1/sqrt(4)) + >>> q.evalf() + 1.00000000000000 + + 0.707106781186547*i + + 0.577350269189626*j + + 0.500000000000000*k + + """ + nprec = prec_to_dps(prec) + return Quaternion(*[arg.evalf(n=nprec) for arg in self.args]) + + def pow_cos_sin(self, p): + """Computes the pth power in the cos-sin form. + + Parameters + ========== + + p : int + Power to be applied on quaternion. + + Returns + ======= + + Quaternion + The p-th power in the cos-sin form. + + Examples + ======== + + >>> from sympy import Quaternion + >>> q = Quaternion(1, 2, 3, 4) + >>> q.pow_cos_sin(4) + 900*cos(4*acos(sqrt(30)/30)) + + 1800*sqrt(29)*sin(4*acos(sqrt(30)/30))/29*i + + 2700*sqrt(29)*sin(4*acos(sqrt(30)/30))/29*j + + 3600*sqrt(29)*sin(4*acos(sqrt(30)/30))/29*k + + """ + # q = ||q||*(cos(a) + u*sin(a)) + # q^p = ||q||^p * (cos(p*a) + u*sin(p*a)) + + q = self + (v, angle) = q.to_axis_angle() + q2 = Quaternion.from_axis_angle(v, p * angle) + return q2 * (q.norm()**p) + + def integrate(self, *args): + """Computes integration of quaternion. + + Returns + ======= + + Quaternion + Integration of the quaternion(self) with the given variable. + + Examples + ======== + + Indefinite Integral of quaternion : + + >>> from sympy import Quaternion + >>> from sympy.abc import x + >>> q = Quaternion(1, 2, 3, 4) + >>> q.integrate(x) + x + 2*x*i + 3*x*j + 4*x*k + + Definite integral of quaternion : + + >>> from sympy import Quaternion + >>> from sympy.abc import x + >>> q = Quaternion(1, 2, 3, 4) + >>> q.integrate((x, 1, 5)) + 4 + 8*i + 12*j + 16*k + + """ + return Quaternion(integrate(self.a, *args), integrate(self.b, *args), + integrate(self.c, *args), integrate(self.d, *args)) + + @staticmethod + def rotate_point(pin, r): + """Returns the coordinates of the point pin (a 3 tuple) after rotation. + + Parameters + ========== + + pin : tuple + A 3-element tuple of coordinates of a point which needs to be + rotated. + r : Quaternion or tuple + Axis and angle of rotation. + + It's important to note that when r is a tuple, it must be of the form + (axis, angle) + + Returns + ======= + + tuple + The coordinates of the point after rotation. + + Examples + ======== + + >>> from sympy import Quaternion + >>> from sympy import symbols, trigsimp, cos, sin + >>> x = symbols('x') + >>> q = Quaternion(cos(x/2), 0, 0, sin(x/2)) + >>> trigsimp(Quaternion.rotate_point((1, 1, 1), q)) + (sqrt(2)*cos(x + pi/4), sqrt(2)*sin(x + pi/4), 1) + >>> (axis, angle) = q.to_axis_angle() + >>> trigsimp(Quaternion.rotate_point((1, 1, 1), (axis, angle))) + (sqrt(2)*cos(x + pi/4), sqrt(2)*sin(x + pi/4), 1) + + """ + if isinstance(r, tuple): + # if r is of the form (vector, angle) + q = Quaternion.from_axis_angle(r[0], r[1]) + else: + # if r is a quaternion + q = r.normalize() + pout = q * Quaternion(0, pin[0], pin[1], pin[2]) * conjugate(q) + return (pout.b, pout.c, pout.d) + + def to_axis_angle(self): + """Returns the axis and angle of rotation of a quaternion. + + Returns + ======= + + tuple + Tuple of (axis, angle) + + Examples + ======== + + >>> from sympy import Quaternion + >>> q = Quaternion(1, 1, 1, 1) + >>> (axis, angle) = q.to_axis_angle() + >>> axis + (sqrt(3)/3, sqrt(3)/3, sqrt(3)/3) + >>> angle + 2*pi/3 + + """ + q = self + if q.a.is_negative: + q = q * -1 + + q = q.normalize() + angle = trigsimp(2 * acos(q.a)) + + # Since quaternion is normalised, q.a is less than 1. + s = sqrt(1 - q.a*q.a) + + x = trigsimp(q.b / s) + y = trigsimp(q.c / s) + z = trigsimp(q.d / s) + + v = (x, y, z) + t = (v, angle) + + return t + + def to_rotation_matrix(self, v=None, homogeneous=True): + """Returns the equivalent rotation transformation matrix of the quaternion + which represents rotation about the origin if ``v`` is not passed. + + Parameters + ========== + + v : tuple or None + Default value: None + homogeneous : bool + When True, gives an expression that may be more efficient for + symbolic calculations but less so for direct evaluation. Both + formulas are mathematically equivalent. + Default value: True + + Returns + ======= + + tuple + Returns the equivalent rotation transformation matrix of the quaternion + which represents rotation about the origin if v is not passed. + + Examples + ======== + + >>> from sympy import Quaternion + >>> from sympy import symbols, trigsimp, cos, sin + >>> x = symbols('x') + >>> q = Quaternion(cos(x/2), 0, 0, sin(x/2)) + >>> trigsimp(q.to_rotation_matrix()) + Matrix([ + [cos(x), -sin(x), 0], + [sin(x), cos(x), 0], + [ 0, 0, 1]]) + + Generates a 4x4 transformation matrix (used for rotation about a point + other than the origin) if the point(v) is passed as an argument. + """ + + q = self + s = q.norm()**-2 + + # diagonal elements are different according to parameter normal + if homogeneous: + m00 = s*(q.a**2 + q.b**2 - q.c**2 - q.d**2) + m11 = s*(q.a**2 - q.b**2 + q.c**2 - q.d**2) + m22 = s*(q.a**2 - q.b**2 - q.c**2 + q.d**2) + else: + m00 = 1 - 2*s*(q.c**2 + q.d**2) + m11 = 1 - 2*s*(q.b**2 + q.d**2) + m22 = 1 - 2*s*(q.b**2 + q.c**2) + + m01 = 2*s*(q.b*q.c - q.d*q.a) + m02 = 2*s*(q.b*q.d + q.c*q.a) + + m10 = 2*s*(q.b*q.c + q.d*q.a) + m12 = 2*s*(q.c*q.d - q.b*q.a) + + m20 = 2*s*(q.b*q.d - q.c*q.a) + m21 = 2*s*(q.c*q.d + q.b*q.a) + + if not v: + return Matrix([[m00, m01, m02], [m10, m11, m12], [m20, m21, m22]]) + + else: + (x, y, z) = v + + m03 = x - x*m00 - y*m01 - z*m02 + m13 = y - x*m10 - y*m11 - z*m12 + m23 = z - x*m20 - y*m21 - z*m22 + m30 = m31 = m32 = 0 + m33 = 1 + + return Matrix([[m00, m01, m02, m03], [m10, m11, m12, m13], + [m20, m21, m22, m23], [m30, m31, m32, m33]]) + + def scalar_part(self): + r"""Returns scalar part($\mathbf{S}(q)$) of the quaternion q. + + Explanation + =========== + + Given a quaternion $q = a + bi + cj + dk$, returns $\mathbf{S}(q) = a$. + + Examples + ======== + + >>> from sympy.algebras.quaternion import Quaternion + >>> q = Quaternion(4, 8, 13, 12) + >>> q.scalar_part() + 4 + + """ + + return self.a + + def vector_part(self): + r""" + Returns $\mathbf{V}(q)$, the vector part of the quaternion $q$. + + Explanation + =========== + + Given a quaternion $q = a + bi + cj + dk$, returns $\mathbf{V}(q) = bi + cj + dk$. + + Examples + ======== + + >>> from sympy.algebras.quaternion import Quaternion + >>> q = Quaternion(1, 1, 1, 1) + >>> q.vector_part() + 0 + 1*i + 1*j + 1*k + + >>> q = Quaternion(4, 8, 13, 12) + >>> q.vector_part() + 0 + 8*i + 13*j + 12*k + + """ + + return Quaternion(0, self.b, self.c, self.d) + + def axis(self): + r""" + Returns $\mathbf{Ax}(q)$, the axis of the quaternion $q$. + + Explanation + =========== + + Given a quaternion $q = a + bi + cj + dk$, returns $\mathbf{Ax}(q)$ i.e., the versor of the vector part of that quaternion + equal to $\mathbf{U}[\mathbf{V}(q)]$. + The axis is always an imaginary unit with square equal to $-1 + 0i + 0j + 0k$. + + Examples + ======== + + >>> from sympy.algebras.quaternion import Quaternion + >>> q = Quaternion(1, 1, 1, 1) + >>> q.axis() + 0 + sqrt(3)/3*i + sqrt(3)/3*j + sqrt(3)/3*k + + See Also + ======== + + vector_part + + """ + axis = self.vector_part().normalize() + + return Quaternion(0, axis.b, axis.c, axis.d) + + def is_pure(self): + """ + Returns true if the quaternion is pure, false if the quaternion is not pure + or returns none if it is unknown. + + Explanation + =========== + + A pure quaternion (also a vector quaternion) is a quaternion with scalar + part equal to 0. + + Examples + ======== + + >>> from sympy.algebras.quaternion import Quaternion + >>> q = Quaternion(0, 8, 13, 12) + >>> q.is_pure() + True + + See Also + ======== + scalar_part + + """ + + return self.a.is_zero + + def is_zero_quaternion(self): + """ + Returns true if the quaternion is a zero quaternion or false if it is not a zero quaternion + and None if the value is unknown. + + Explanation + =========== + + A zero quaternion is a quaternion with both scalar part and + vector part equal to 0. + + Examples + ======== + + >>> from sympy.algebras.quaternion import Quaternion + >>> q = Quaternion(1, 0, 0, 0) + >>> q.is_zero_quaternion() + False + + >>> q = Quaternion(0, 0, 0, 0) + >>> q.is_zero_quaternion() + True + + See Also + ======== + scalar_part + vector_part + + """ + + return self.norm().is_zero + + def angle(self): + r""" + Returns the angle of the quaternion measured in the real-axis plane. + + Explanation + =========== + + Given a quaternion $q = a + bi + cj + dk$ where $a$, $b$, $c$ and $d$ + are real numbers, returns the angle of the quaternion given by + + .. math:: + \theta := 2 \operatorname{atan_2}\left(\sqrt{b^2 + c^2 + d^2}, {a}\right) + + Examples + ======== + + >>> from sympy.algebras.quaternion import Quaternion + >>> q = Quaternion(1, 4, 4, 4) + >>> q.angle() + 2*atan(4*sqrt(3)) + + """ + + return 2 * atan2(self.vector_part().norm(), self.scalar_part()) + + + def arc_coplanar(self, other): + """ + Returns True if the transformation arcs represented by the input quaternions happen in the same plane. + + Explanation + =========== + + Two quaternions are said to be coplanar (in this arc sense) when their axes are parallel. + The plane of a quaternion is the one normal to its axis. + + Parameters + ========== + + other : a Quaternion + + Returns + ======= + + True : if the planes of the two quaternions are the same, apart from its orientation/sign. + False : if the planes of the two quaternions are not the same, apart from its orientation/sign. + None : if plane of either of the quaternion is unknown. + + Examples + ======== + + >>> from sympy.algebras.quaternion import Quaternion + >>> q1 = Quaternion(1, 4, 4, 4) + >>> q2 = Quaternion(3, 8, 8, 8) + >>> Quaternion.arc_coplanar(q1, q2) + True + + >>> q1 = Quaternion(2, 8, 13, 12) + >>> Quaternion.arc_coplanar(q1, q2) + False + + See Also + ======== + + vector_coplanar + is_pure + + """ + if (self.is_zero_quaternion()) or (other.is_zero_quaternion()): + raise ValueError('Neither of the given quaternions can be 0') + + return fuzzy_or([(self.axis() - other.axis()).is_zero_quaternion(), (self.axis() + other.axis()).is_zero_quaternion()]) + + @classmethod + def vector_coplanar(cls, q1, q2, q3): + r""" + Returns True if the axis of the pure quaternions seen as 3D vectors + ``q1``, ``q2``, and ``q3`` are coplanar. + + Explanation + =========== + + Three pure quaternions are vector coplanar if the quaternions seen as 3D vectors are coplanar. + + Parameters + ========== + + q1 + A pure Quaternion. + q2 + A pure Quaternion. + q3 + A pure Quaternion. + + Returns + ======= + + True : if the axis of the pure quaternions seen as 3D vectors + q1, q2, and q3 are coplanar. + False : if the axis of the pure quaternions seen as 3D vectors + q1, q2, and q3 are not coplanar. + None : if the axis of the pure quaternions seen as 3D vectors + q1, q2, and q3 are coplanar is unknown. + + Examples + ======== + + >>> from sympy.algebras.quaternion import Quaternion + >>> q1 = Quaternion(0, 4, 4, 4) + >>> q2 = Quaternion(0, 8, 8, 8) + >>> q3 = Quaternion(0, 24, 24, 24) + >>> Quaternion.vector_coplanar(q1, q2, q3) + True + + >>> q1 = Quaternion(0, 8, 16, 8) + >>> q2 = Quaternion(0, 8, 3, 12) + >>> Quaternion.vector_coplanar(q1, q2, q3) + False + + See Also + ======== + + axis + is_pure + + """ + + if fuzzy_not(q1.is_pure()) or fuzzy_not(q2.is_pure()) or fuzzy_not(q3.is_pure()): + raise ValueError('The given quaternions must be pure') + + M = Matrix([[q1.b, q1.c, q1.d], [q2.b, q2.c, q2.d], [q3.b, q3.c, q3.d]]).det() + return M.is_zero + + def parallel(self, other): + """ + Returns True if the two pure quaternions seen as 3D vectors are parallel. + + Explanation + =========== + + Two pure quaternions are called parallel when their vector product is commutative which + implies that the quaternions seen as 3D vectors have same direction. + + Parameters + ========== + + other : a Quaternion + + Returns + ======= + + True : if the two pure quaternions seen as 3D vectors are parallel. + False : if the two pure quaternions seen as 3D vectors are not parallel. + None : if the two pure quaternions seen as 3D vectors are parallel is unknown. + + Examples + ======== + + >>> from sympy.algebras.quaternion import Quaternion + >>> q = Quaternion(0, 4, 4, 4) + >>> q1 = Quaternion(0, 8, 8, 8) + >>> q.parallel(q1) + True + + >>> q1 = Quaternion(0, 8, 13, 12) + >>> q.parallel(q1) + False + + """ + + if fuzzy_not(self.is_pure()) or fuzzy_not(other.is_pure()): + raise ValueError('The provided quaternions must be pure') + + return (self*other - other*self).is_zero_quaternion() + + def orthogonal(self, other): + """ + Returns the orthogonality of two quaternions. + + Explanation + =========== + + Two pure quaternions are called orthogonal when their product is anti-commutative. + + Parameters + ========== + + other : a Quaternion + + Returns + ======= + + True : if the two pure quaternions seen as 3D vectors are orthogonal. + False : if the two pure quaternions seen as 3D vectors are not orthogonal. + None : if the two pure quaternions seen as 3D vectors are orthogonal is unknown. + + Examples + ======== + + >>> from sympy.algebras.quaternion import Quaternion + >>> q = Quaternion(0, 4, 4, 4) + >>> q1 = Quaternion(0, 8, 8, 8) + >>> q.orthogonal(q1) + False + + >>> q1 = Quaternion(0, 2, 2, 0) + >>> q = Quaternion(0, 2, -2, 0) + >>> q.orthogonal(q1) + True + + """ + + if fuzzy_not(self.is_pure()) or fuzzy_not(other.is_pure()): + raise ValueError('The given quaternions must be pure') + + return (self*other + other*self).is_zero_quaternion() + + def index_vector(self): + r""" + Returns the index vector of the quaternion. + + Explanation + =========== + + The index vector is given by $\mathbf{T}(q)$, the norm (or magnitude) of + the quaternion $q$, multiplied by $\mathbf{Ax}(q)$, the axis of $q$. + + Returns + ======= + + Quaternion: representing index vector of the provided quaternion. + + Examples + ======== + + >>> from sympy.algebras.quaternion import Quaternion + >>> q = Quaternion(2, 4, 2, 4) + >>> q.index_vector() + 0 + 4*sqrt(10)/3*i + 2*sqrt(10)/3*j + 4*sqrt(10)/3*k + + See Also + ======== + + axis + norm + + """ + + return self.norm() * self.axis() + + def mensor(self): + """ + Returns the natural logarithm of the norm(magnitude) of the quaternion. + + Examples + ======== + + >>> from sympy.algebras.quaternion import Quaternion + >>> q = Quaternion(2, 4, 2, 4) + >>> q.mensor() + log(2*sqrt(10)) + >>> q.norm() + 2*sqrt(10) + + See Also + ======== + + norm + + """ + + return ln(self.norm()) diff --git a/.venv/lib/python3.13/site-packages/sympy/assumptions/__init__.py b/.venv/lib/python3.13/site-packages/sympy/assumptions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..af862653f3ce0eeb67f7764e16c32f3466e87024 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/assumptions/__init__.py @@ -0,0 +1,18 @@ +""" +A module to implement logical predicates and assumption system. +""" + +from .assume import ( + AppliedPredicate, Predicate, AssumptionsContext, assuming, + global_assumptions +) +from .ask import Q, ask, register_handler, remove_handler +from .refine import refine +from .relation import BinaryRelation, AppliedBinaryRelation + +__all__ = [ + 'AppliedPredicate', 'Predicate', 'AssumptionsContext', 'assuming', + 'global_assumptions', 'Q', 'ask', 'register_handler', 'remove_handler', + 'refine', + 'BinaryRelation', 'AppliedBinaryRelation' +] diff --git a/.venv/lib/python3.13/site-packages/sympy/assumptions/ask.py b/.venv/lib/python3.13/site-packages/sympy/assumptions/ask.py new file mode 100644 index 0000000000000000000000000000000000000000..ec81ec8ecce245c2a798cf9e71af1e9373292bc1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/assumptions/ask.py @@ -0,0 +1,651 @@ +"""Module for querying SymPy objects about assumptions.""" + +from sympy.assumptions.assume import (global_assumptions, Predicate, + AppliedPredicate) +from sympy.assumptions.cnf import CNF, EncodedCNF, Literal +from sympy.core import sympify +from sympy.core.kind import BooleanKind +from sympy.core.relational import Eq, Ne, Gt, Lt, Ge, Le +from sympy.logic.inference import satisfiable +from sympy.utilities.decorator import memoize_property +from sympy.utilities.exceptions import (sympy_deprecation_warning, + SymPyDeprecationWarning, + ignore_warnings) + + +# Memoization is necessary for the properties of AssumptionKeys to +# ensure that only one object of Predicate objects are created. +# This is because assumption handlers are registered on those objects. + + +class AssumptionKeys: + """ + This class contains all the supported keys by ``ask``. + It should be accessed via the instance ``sympy.Q``. + + """ + + # DO NOT add methods or properties other than predicate keys. + # SAT solver checks the properties of Q and use them to compute the + # fact system. Non-predicate attributes will break this. + + @memoize_property + def hermitian(self): + from .handlers.sets import HermitianPredicate + return HermitianPredicate() + + @memoize_property + def antihermitian(self): + from .handlers.sets import AntihermitianPredicate + return AntihermitianPredicate() + + @memoize_property + def real(self): + from .handlers.sets import RealPredicate + return RealPredicate() + + @memoize_property + def extended_real(self): + from .handlers.sets import ExtendedRealPredicate + return ExtendedRealPredicate() + + @memoize_property + def imaginary(self): + from .handlers.sets import ImaginaryPredicate + return ImaginaryPredicate() + + @memoize_property + def complex(self): + from .handlers.sets import ComplexPredicate + return ComplexPredicate() + + @memoize_property + def algebraic(self): + from .handlers.sets import AlgebraicPredicate + return AlgebraicPredicate() + + @memoize_property + def transcendental(self): + from .predicates.sets import TranscendentalPredicate + return TranscendentalPredicate() + + @memoize_property + def integer(self): + from .handlers.sets import IntegerPredicate + return IntegerPredicate() + + @memoize_property + def noninteger(self): + from .predicates.sets import NonIntegerPredicate + return NonIntegerPredicate() + + @memoize_property + def rational(self): + from .handlers.sets import RationalPredicate + return RationalPredicate() + + @memoize_property + def irrational(self): + from .handlers.sets import IrrationalPredicate + return IrrationalPredicate() + + @memoize_property + def finite(self): + from .handlers.calculus import FinitePredicate + return FinitePredicate() + + @memoize_property + def infinite(self): + from .handlers.calculus import InfinitePredicate + return InfinitePredicate() + + @memoize_property + def positive_infinite(self): + from .handlers.calculus import PositiveInfinitePredicate + return PositiveInfinitePredicate() + + @memoize_property + def negative_infinite(self): + from .handlers.calculus import NegativeInfinitePredicate + return NegativeInfinitePredicate() + + @memoize_property + def positive(self): + from .handlers.order import PositivePredicate + return PositivePredicate() + + @memoize_property + def negative(self): + from .handlers.order import NegativePredicate + return NegativePredicate() + + @memoize_property + def zero(self): + from .handlers.order import ZeroPredicate + return ZeroPredicate() + + @memoize_property + def extended_positive(self): + from .handlers.order import ExtendedPositivePredicate + return ExtendedPositivePredicate() + + @memoize_property + def extended_negative(self): + from .handlers.order import ExtendedNegativePredicate + return ExtendedNegativePredicate() + + @memoize_property + def nonzero(self): + from .handlers.order import NonZeroPredicate + return NonZeroPredicate() + + @memoize_property + def nonpositive(self): + from .handlers.order import NonPositivePredicate + return NonPositivePredicate() + + @memoize_property + def nonnegative(self): + from .handlers.order import NonNegativePredicate + return NonNegativePredicate() + + @memoize_property + def extended_nonzero(self): + from .handlers.order import ExtendedNonZeroPredicate + return ExtendedNonZeroPredicate() + + @memoize_property + def extended_nonpositive(self): + from .handlers.order import ExtendedNonPositivePredicate + return ExtendedNonPositivePredicate() + + @memoize_property + def extended_nonnegative(self): + from .handlers.order import ExtendedNonNegativePredicate + return ExtendedNonNegativePredicate() + + @memoize_property + def even(self): + from .handlers.ntheory import EvenPredicate + return EvenPredicate() + + @memoize_property + def odd(self): + from .handlers.ntheory import OddPredicate + return OddPredicate() + + @memoize_property + def prime(self): + from .handlers.ntheory import PrimePredicate + return PrimePredicate() + + @memoize_property + def composite(self): + from .handlers.ntheory import CompositePredicate + return CompositePredicate() + + @memoize_property + def commutative(self): + from .handlers.common import CommutativePredicate + return CommutativePredicate() + + @memoize_property + def is_true(self): + from .handlers.common import IsTruePredicate + return IsTruePredicate() + + @memoize_property + def symmetric(self): + from .handlers.matrices import SymmetricPredicate + return SymmetricPredicate() + + @memoize_property + def invertible(self): + from .handlers.matrices import InvertiblePredicate + return InvertiblePredicate() + + @memoize_property + def orthogonal(self): + from .handlers.matrices import OrthogonalPredicate + return OrthogonalPredicate() + + @memoize_property + def unitary(self): + from .handlers.matrices import UnitaryPredicate + return UnitaryPredicate() + + @memoize_property + def positive_definite(self): + from .handlers.matrices import PositiveDefinitePredicate + return PositiveDefinitePredicate() + + @memoize_property + def upper_triangular(self): + from .handlers.matrices import UpperTriangularPredicate + return UpperTriangularPredicate() + + @memoize_property + def lower_triangular(self): + from .handlers.matrices import LowerTriangularPredicate + return LowerTriangularPredicate() + + @memoize_property + def diagonal(self): + from .handlers.matrices import DiagonalPredicate + return DiagonalPredicate() + + @memoize_property + def fullrank(self): + from .handlers.matrices import FullRankPredicate + return FullRankPredicate() + + @memoize_property + def square(self): + from .handlers.matrices import SquarePredicate + return SquarePredicate() + + @memoize_property + def integer_elements(self): + from .handlers.matrices import IntegerElementsPredicate + return IntegerElementsPredicate() + + @memoize_property + def real_elements(self): + from .handlers.matrices import RealElementsPredicate + return RealElementsPredicate() + + @memoize_property + def complex_elements(self): + from .handlers.matrices import ComplexElementsPredicate + return ComplexElementsPredicate() + + @memoize_property + def singular(self): + from .predicates.matrices import SingularPredicate + return SingularPredicate() + + @memoize_property + def normal(self): + from .predicates.matrices import NormalPredicate + return NormalPredicate() + + @memoize_property + def triangular(self): + from .predicates.matrices import TriangularPredicate + return TriangularPredicate() + + @memoize_property + def unit_triangular(self): + from .predicates.matrices import UnitTriangularPredicate + return UnitTriangularPredicate() + + @memoize_property + def eq(self): + from .relation.equality import EqualityPredicate + return EqualityPredicate() + + @memoize_property + def ne(self): + from .relation.equality import UnequalityPredicate + return UnequalityPredicate() + + @memoize_property + def gt(self): + from .relation.equality import StrictGreaterThanPredicate + return StrictGreaterThanPredicate() + + @memoize_property + def ge(self): + from .relation.equality import GreaterThanPredicate + return GreaterThanPredicate() + + @memoize_property + def lt(self): + from .relation.equality import StrictLessThanPredicate + return StrictLessThanPredicate() + + @memoize_property + def le(self): + from .relation.equality import LessThanPredicate + return LessThanPredicate() + + +Q = AssumptionKeys() + +def _extract_all_facts(assump, exprs): + """ + Extract all relevant assumptions from *assump* with respect to given *exprs*. + + Parameters + ========== + + assump : sympy.assumptions.cnf.CNF + + exprs : tuple of expressions + + Returns + ======= + + sympy.assumptions.cnf.CNF + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.cnf import CNF + >>> from sympy.assumptions.ask import _extract_all_facts + >>> from sympy.abc import x, y + >>> assump = CNF.from_prop(Q.positive(x) & Q.integer(y)) + >>> exprs = (x,) + >>> cnf = _extract_all_facts(assump, exprs) + >>> cnf.clauses + {frozenset({Literal(Q.positive, False)})} + + """ + facts = set() + + for clause in assump.clauses: + args = [] + for literal in clause: + if isinstance(literal.lit, AppliedPredicate) and len(literal.lit.arguments) == 1: + if literal.lit.arg in exprs: + # Add literal if it has matching in it + args.append(Literal(literal.lit.function, literal.is_Not)) + else: + # If any of the literals doesn't have matching expr don't add the whole clause. + break + else: + # If any of the literals aren't unary predicate don't add the whole clause. + break + + else: + if args: + facts.add(frozenset(args)) + return CNF(facts) + + +def ask(proposition, assumptions=True, context=global_assumptions): + """ + Function to evaluate the proposition with assumptions. + + Explanation + =========== + + This function evaluates the proposition to ``True`` or ``False`` if + the truth value can be determined. If not, it returns ``None``. + + It should be discerned from :func:`~.refine` which, when applied to a + proposition, simplifies the argument to symbolic ``Boolean`` instead of + Python built-in ``True``, ``False`` or ``None``. + + **Syntax** + + * ask(proposition) + Evaluate the *proposition* in global assumption context. + + * ask(proposition, assumptions) + Evaluate the *proposition* with respect to *assumptions* in + global assumption context. + + Parameters + ========== + + proposition : Boolean + Proposition which will be evaluated to boolean value. If this is + not ``AppliedPredicate``, it will be wrapped by ``Q.is_true``. + + assumptions : Boolean, optional + Local assumptions to evaluate the *proposition*. + + context : AssumptionsContext, optional + Default assumptions to evaluate the *proposition*. By default, + this is ``sympy.assumptions.global_assumptions`` variable. + + Returns + ======= + + ``True``, ``False``, or ``None`` + + Raises + ====== + + TypeError : *proposition* or *assumptions* is not valid logical expression. + + ValueError : assumptions are inconsistent. + + Examples + ======== + + >>> from sympy import ask, Q, pi + >>> from sympy.abc import x, y + >>> ask(Q.rational(pi)) + False + >>> ask(Q.even(x*y), Q.even(x) & Q.integer(y)) + True + >>> ask(Q.prime(4*x), Q.integer(x)) + False + + If the truth value cannot be determined, ``None`` will be returned. + + >>> print(ask(Q.odd(3*x))) # cannot determine unless we know x + None + + ``ValueError`` is raised if assumptions are inconsistent. + + >>> ask(Q.integer(x), Q.even(x) & Q.odd(x)) + Traceback (most recent call last): + ... + ValueError: inconsistent assumptions Q.even(x) & Q.odd(x) + + Notes + ===== + + Relations in assumptions are not implemented (yet), so the following + will not give a meaningful result. + + >>> ask(Q.positive(x), x > 0) + + It is however a work in progress. + + See Also + ======== + + sympy.assumptions.refine.refine : Simplification using assumptions. + Proposition is not reduced to ``None`` if the truth value cannot + be determined. + """ + from sympy.assumptions.satask import satask + from sympy.assumptions.lra_satask import lra_satask + from sympy.logic.algorithms.lra_theory import UnhandledInput + + proposition = sympify(proposition) + assumptions = sympify(assumptions) + + if isinstance(proposition, Predicate) or proposition.kind is not BooleanKind: + raise TypeError("proposition must be a valid logical expression") + + if isinstance(assumptions, Predicate) or assumptions.kind is not BooleanKind: + raise TypeError("assumptions must be a valid logical expression") + + binrelpreds = {Eq: Q.eq, Ne: Q.ne, Gt: Q.gt, Lt: Q.lt, Ge: Q.ge, Le: Q.le} + if isinstance(proposition, AppliedPredicate): + key, args = proposition.function, proposition.arguments + elif proposition.func in binrelpreds: + key, args = binrelpreds[type(proposition)], proposition.args + else: + key, args = Q.is_true, (proposition,) + + # convert local and global assumptions to CNF + assump_cnf = CNF.from_prop(assumptions) + assump_cnf.extend(context) + + # extract the relevant facts from assumptions with respect to args + local_facts = _extract_all_facts(assump_cnf, args) + + # convert default facts and assumed facts to encoded CNF + known_facts_cnf = get_all_known_facts() + enc_cnf = EncodedCNF() + enc_cnf.from_cnf(CNF(known_facts_cnf)) + enc_cnf.add_from_cnf(local_facts) + + # check the satisfiability of given assumptions + if local_facts.clauses and satisfiable(enc_cnf) is False: + raise ValueError("inconsistent assumptions %s" % assumptions) + + # quick computation for single fact + res = _ask_single_fact(key, local_facts) + if res is not None: + return res + + # direct resolution method, no logic + res = key(*args)._eval_ask(assumptions) + if res is not None: + return bool(res) + + # using satask (still costly) + res = satask(proposition, assumptions=assumptions, context=context) + if res is not None: + return res + + try: + res = lra_satask(proposition, assumptions=assumptions, context=context) + except UnhandledInput: + return None + + return res + + +def _ask_single_fact(key, local_facts): + """ + Compute the truth value of single predicate using assumptions. + + Parameters + ========== + + key : sympy.assumptions.assume.Predicate + Proposition predicate. + + local_facts : sympy.assumptions.cnf.CNF + Local assumption in CNF form. + + Returns + ======= + + ``True``, ``False`` or ``None`` + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.cnf import CNF + >>> from sympy.assumptions.ask import _ask_single_fact + + If prerequisite of proposition is rejected by the assumption, + return ``False``. + + >>> key, assump = Q.zero, ~Q.zero + >>> local_facts = CNF.from_prop(assump) + >>> _ask_single_fact(key, local_facts) + False + >>> key, assump = Q.zero, ~Q.even + >>> local_facts = CNF.from_prop(assump) + >>> _ask_single_fact(key, local_facts) + False + + If assumption implies the proposition, return ``True``. + + >>> key, assump = Q.even, Q.zero + >>> local_facts = CNF.from_prop(assump) + >>> _ask_single_fact(key, local_facts) + True + + If proposition rejects the assumption, return ``False``. + + >>> key, assump = Q.even, Q.odd + >>> local_facts = CNF.from_prop(assump) + >>> _ask_single_fact(key, local_facts) + False + """ + if local_facts.clauses: + + known_facts_dict = get_known_facts_dict() + + if len(local_facts.clauses) == 1: + cl, = local_facts.clauses + if len(cl) == 1: + f, = cl + prop_facts = known_facts_dict.get(key, None) + prop_req = prop_facts[0] if prop_facts is not None else set() + if f.is_Not and f.arg in prop_req: + # the prerequisite of proposition is rejected + return False + + for clause in local_facts.clauses: + if len(clause) == 1: + f, = clause + prop_facts = known_facts_dict.get(f.arg, None) if not f.is_Not else None + if prop_facts is None: + continue + + prop_req, prop_rej = prop_facts + if key in prop_req: + # assumption implies the proposition + return True + elif key in prop_rej: + # proposition rejects the assumption + return False + + return None + + +def register_handler(key, handler): + """ + Register a handler in the ask system. key must be a string and handler a + class inheriting from AskHandler. + + .. deprecated:: 1.8. + Use multipledispatch handler instead. See :obj:`~.Predicate`. + + """ + sympy_deprecation_warning( + """ + The AskHandler system is deprecated. The register_handler() function + should be replaced with the multipledispatch handler of Predicate. + """, + deprecated_since_version="1.8", + active_deprecations_target='deprecated-askhandler', + ) + if isinstance(key, Predicate): + key = key.name.name + Qkey = getattr(Q, key, None) + if Qkey is not None: + Qkey.add_handler(handler) + else: + setattr(Q, key, Predicate(key, handlers=[handler])) + + +def remove_handler(key, handler): + """ + Removes a handler from the ask system. + + .. deprecated:: 1.8. + Use multipledispatch handler instead. See :obj:`~.Predicate`. + + """ + sympy_deprecation_warning( + """ + The AskHandler system is deprecated. The remove_handler() function + should be replaced with the multipledispatch handler of Predicate. + """, + deprecated_since_version="1.8", + active_deprecations_target='deprecated-askhandler', + ) + if isinstance(key, Predicate): + key = key.name.name + # Don't show the same warning again recursively + with ignore_warnings(SymPyDeprecationWarning): + getattr(Q, key).remove_handler(handler) + + +from sympy.assumptions.ask_generated import (get_all_known_facts, + get_known_facts_dict) diff --git a/.venv/lib/python3.13/site-packages/sympy/assumptions/ask_generated.py b/.venv/lib/python3.13/site-packages/sympy/assumptions/ask_generated.py new file mode 100644 index 0000000000000000000000000000000000000000..d90cdffc1e127d78e18f70cda13d8d5e0530d41b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/assumptions/ask_generated.py @@ -0,0 +1,352 @@ +""" +Do NOT manually edit this file. +Instead, run ./bin/ask_update.py. +""" + +from sympy.assumptions.ask import Q +from sympy.assumptions.cnf import Literal +from sympy.core.cache import cacheit + +@cacheit +def get_all_known_facts(): + """ + Known facts between unary predicates as CNF clauses. + """ + return { + frozenset((Literal(Q.algebraic, False), Literal(Q.imaginary, True), Literal(Q.transcendental, False))), + frozenset((Literal(Q.algebraic, False), Literal(Q.negative, True), Literal(Q.transcendental, False))), + frozenset((Literal(Q.algebraic, False), Literal(Q.positive, True), Literal(Q.transcendental, False))), + frozenset((Literal(Q.algebraic, False), Literal(Q.rational, True))), + frozenset((Literal(Q.algebraic, False), Literal(Q.transcendental, False), Literal(Q.zero, True))), + frozenset((Literal(Q.algebraic, True), Literal(Q.finite, False))), + frozenset((Literal(Q.algebraic, True), Literal(Q.transcendental, True))), + frozenset((Literal(Q.antihermitian, False), Literal(Q.hermitian, False), Literal(Q.zero, True))), + frozenset((Literal(Q.antihermitian, False), Literal(Q.imaginary, True))), + frozenset((Literal(Q.commutative, False), Literal(Q.finite, True))), + frozenset((Literal(Q.commutative, False), Literal(Q.infinite, True))), + frozenset((Literal(Q.complex_elements, False), Literal(Q.real_elements, True))), + frozenset((Literal(Q.composite, False), Literal(Q.even, True), Literal(Q.positive, True), Literal(Q.prime, False))), + frozenset((Literal(Q.composite, True), Literal(Q.even, False), Literal(Q.odd, False))), + frozenset((Literal(Q.composite, True), Literal(Q.positive, False))), + frozenset((Literal(Q.composite, True), Literal(Q.prime, True))), + frozenset((Literal(Q.diagonal, False), Literal(Q.lower_triangular, True), Literal(Q.upper_triangular, True))), + frozenset((Literal(Q.diagonal, True), Literal(Q.lower_triangular, False))), + frozenset((Literal(Q.diagonal, True), Literal(Q.normal, False))), + frozenset((Literal(Q.diagonal, True), Literal(Q.symmetric, False))), + frozenset((Literal(Q.diagonal, True), Literal(Q.upper_triangular, False))), + frozenset((Literal(Q.even, False), Literal(Q.odd, False), Literal(Q.prime, True))), + frozenset((Literal(Q.even, False), Literal(Q.zero, True))), + frozenset((Literal(Q.even, True), Literal(Q.odd, True))), + frozenset((Literal(Q.even, True), Literal(Q.rational, False))), + frozenset((Literal(Q.finite, False), Literal(Q.transcendental, True))), + frozenset((Literal(Q.finite, True), Literal(Q.infinite, True))), + frozenset((Literal(Q.fullrank, False), Literal(Q.invertible, True))), + frozenset((Literal(Q.fullrank, True), Literal(Q.invertible, False), Literal(Q.square, True))), + frozenset((Literal(Q.hermitian, False), Literal(Q.negative, True))), + frozenset((Literal(Q.hermitian, False), Literal(Q.positive, True))), + frozenset((Literal(Q.hermitian, False), Literal(Q.zero, True))), + frozenset((Literal(Q.imaginary, True), Literal(Q.negative, True))), + frozenset((Literal(Q.imaginary, True), Literal(Q.positive, True))), + frozenset((Literal(Q.imaginary, True), Literal(Q.zero, True))), + frozenset((Literal(Q.infinite, False), Literal(Q.negative_infinite, True))), + frozenset((Literal(Q.infinite, False), Literal(Q.positive_infinite, True))), + frozenset((Literal(Q.integer_elements, True), Literal(Q.real_elements, False))), + frozenset((Literal(Q.invertible, False), Literal(Q.positive_definite, True))), + frozenset((Literal(Q.invertible, False), Literal(Q.singular, False))), + frozenset((Literal(Q.invertible, False), Literal(Q.unitary, True))), + frozenset((Literal(Q.invertible, True), Literal(Q.singular, True))), + frozenset((Literal(Q.invertible, True), Literal(Q.square, False))), + frozenset((Literal(Q.irrational, False), Literal(Q.negative, True), Literal(Q.rational, False))), + frozenset((Literal(Q.irrational, False), Literal(Q.positive, True), Literal(Q.rational, False))), + frozenset((Literal(Q.irrational, False), Literal(Q.rational, False), Literal(Q.zero, True))), + frozenset((Literal(Q.irrational, True), Literal(Q.negative, False), Literal(Q.positive, False), Literal(Q.zero, False))), + frozenset((Literal(Q.irrational, True), Literal(Q.rational, True))), + frozenset((Literal(Q.lower_triangular, False), Literal(Q.triangular, True), Literal(Q.upper_triangular, False))), + frozenset((Literal(Q.lower_triangular, True), Literal(Q.triangular, False))), + frozenset((Literal(Q.negative, False), Literal(Q.positive, False), Literal(Q.rational, True), Literal(Q.zero, False))), + frozenset((Literal(Q.negative, True), Literal(Q.negative_infinite, True))), + frozenset((Literal(Q.negative, True), Literal(Q.positive, True))), + frozenset((Literal(Q.negative, True), Literal(Q.positive_infinite, True))), + frozenset((Literal(Q.negative, True), Literal(Q.zero, True))), + frozenset((Literal(Q.negative_infinite, True), Literal(Q.positive, True))), + frozenset((Literal(Q.negative_infinite, True), Literal(Q.positive_infinite, True))), + frozenset((Literal(Q.negative_infinite, True), Literal(Q.zero, True))), + frozenset((Literal(Q.normal, False), Literal(Q.unitary, True))), + frozenset((Literal(Q.normal, True), Literal(Q.square, False))), + frozenset((Literal(Q.odd, True), Literal(Q.rational, False))), + frozenset((Literal(Q.orthogonal, False), Literal(Q.real_elements, True), Literal(Q.unitary, True))), + frozenset((Literal(Q.orthogonal, True), Literal(Q.positive_definite, False))), + frozenset((Literal(Q.orthogonal, True), Literal(Q.unitary, False))), + frozenset((Literal(Q.positive, False), Literal(Q.prime, True))), + frozenset((Literal(Q.positive, True), Literal(Q.positive_infinite, True))), + frozenset((Literal(Q.positive, True), Literal(Q.zero, True))), + frozenset((Literal(Q.positive_infinite, True), Literal(Q.zero, True))), + frozenset((Literal(Q.square, False), Literal(Q.symmetric, True))), + frozenset((Literal(Q.triangular, False), Literal(Q.unit_triangular, True))), + frozenset((Literal(Q.triangular, False), Literal(Q.upper_triangular, True))) + } + +@cacheit +def get_all_known_matrix_facts(): + """ + Known facts between unary predicates for matrices as CNF clauses. + """ + return { + frozenset((Literal(Q.complex_elements, False), Literal(Q.real_elements, True))), + frozenset((Literal(Q.diagonal, False), Literal(Q.lower_triangular, True), Literal(Q.upper_triangular, True))), + frozenset((Literal(Q.diagonal, True), Literal(Q.lower_triangular, False))), + frozenset((Literal(Q.diagonal, True), Literal(Q.normal, False))), + frozenset((Literal(Q.diagonal, True), Literal(Q.symmetric, False))), + frozenset((Literal(Q.diagonal, True), Literal(Q.upper_triangular, False))), + frozenset((Literal(Q.fullrank, False), Literal(Q.invertible, True))), + frozenset((Literal(Q.fullrank, True), Literal(Q.invertible, False), Literal(Q.square, True))), + frozenset((Literal(Q.integer_elements, True), Literal(Q.real_elements, False))), + frozenset((Literal(Q.invertible, False), Literal(Q.positive_definite, True))), + frozenset((Literal(Q.invertible, False), Literal(Q.singular, False))), + frozenset((Literal(Q.invertible, False), Literal(Q.unitary, True))), + frozenset((Literal(Q.invertible, True), Literal(Q.singular, True))), + frozenset((Literal(Q.invertible, True), Literal(Q.square, False))), + frozenset((Literal(Q.lower_triangular, False), Literal(Q.triangular, True), Literal(Q.upper_triangular, False))), + frozenset((Literal(Q.lower_triangular, True), Literal(Q.triangular, False))), + frozenset((Literal(Q.normal, False), Literal(Q.unitary, True))), + frozenset((Literal(Q.normal, True), Literal(Q.square, False))), + frozenset((Literal(Q.orthogonal, False), Literal(Q.real_elements, True), Literal(Q.unitary, True))), + frozenset((Literal(Q.orthogonal, True), Literal(Q.positive_definite, False))), + frozenset((Literal(Q.orthogonal, True), Literal(Q.unitary, False))), + frozenset((Literal(Q.square, False), Literal(Q.symmetric, True))), + frozenset((Literal(Q.triangular, False), Literal(Q.unit_triangular, True))), + frozenset((Literal(Q.triangular, False), Literal(Q.upper_triangular, True))) + } + +@cacheit +def get_all_known_number_facts(): + """ + Known facts between unary predicates for numbers as CNF clauses. + """ + return { + frozenset((Literal(Q.algebraic, False), Literal(Q.imaginary, True), Literal(Q.transcendental, False))), + frozenset((Literal(Q.algebraic, False), Literal(Q.negative, True), Literal(Q.transcendental, False))), + frozenset((Literal(Q.algebraic, False), Literal(Q.positive, True), Literal(Q.transcendental, False))), + frozenset((Literal(Q.algebraic, False), Literal(Q.rational, True))), + frozenset((Literal(Q.algebraic, False), Literal(Q.transcendental, False), Literal(Q.zero, True))), + frozenset((Literal(Q.algebraic, True), Literal(Q.finite, False))), + frozenset((Literal(Q.algebraic, True), Literal(Q.transcendental, True))), + frozenset((Literal(Q.antihermitian, False), Literal(Q.hermitian, False), Literal(Q.zero, True))), + frozenset((Literal(Q.antihermitian, False), Literal(Q.imaginary, True))), + frozenset((Literal(Q.commutative, False), Literal(Q.finite, True))), + frozenset((Literal(Q.commutative, False), Literal(Q.infinite, True))), + frozenset((Literal(Q.composite, False), Literal(Q.even, True), Literal(Q.positive, True), Literal(Q.prime, False))), + frozenset((Literal(Q.composite, True), Literal(Q.even, False), Literal(Q.odd, False))), + frozenset((Literal(Q.composite, True), Literal(Q.positive, False))), + frozenset((Literal(Q.composite, True), Literal(Q.prime, True))), + frozenset((Literal(Q.even, False), Literal(Q.odd, False), Literal(Q.prime, True))), + frozenset((Literal(Q.even, False), Literal(Q.zero, True))), + frozenset((Literal(Q.even, True), Literal(Q.odd, True))), + frozenset((Literal(Q.even, True), Literal(Q.rational, False))), + frozenset((Literal(Q.finite, False), Literal(Q.transcendental, True))), + frozenset((Literal(Q.finite, True), Literal(Q.infinite, True))), + frozenset((Literal(Q.hermitian, False), Literal(Q.negative, True))), + frozenset((Literal(Q.hermitian, False), Literal(Q.positive, True))), + frozenset((Literal(Q.hermitian, False), Literal(Q.zero, True))), + frozenset((Literal(Q.imaginary, True), Literal(Q.negative, True))), + frozenset((Literal(Q.imaginary, True), Literal(Q.positive, True))), + frozenset((Literal(Q.imaginary, True), Literal(Q.zero, True))), + frozenset((Literal(Q.infinite, False), Literal(Q.negative_infinite, True))), + frozenset((Literal(Q.infinite, False), Literal(Q.positive_infinite, True))), + frozenset((Literal(Q.irrational, False), Literal(Q.negative, True), Literal(Q.rational, False))), + frozenset((Literal(Q.irrational, False), Literal(Q.positive, True), Literal(Q.rational, False))), + frozenset((Literal(Q.irrational, False), Literal(Q.rational, False), Literal(Q.zero, True))), + frozenset((Literal(Q.irrational, True), Literal(Q.negative, False), Literal(Q.positive, False), Literal(Q.zero, False))), + frozenset((Literal(Q.irrational, True), Literal(Q.rational, True))), + frozenset((Literal(Q.negative, False), Literal(Q.positive, False), Literal(Q.rational, True), Literal(Q.zero, False))), + frozenset((Literal(Q.negative, True), Literal(Q.negative_infinite, True))), + frozenset((Literal(Q.negative, True), Literal(Q.positive, True))), + frozenset((Literal(Q.negative, True), Literal(Q.positive_infinite, True))), + frozenset((Literal(Q.negative, True), Literal(Q.zero, True))), + frozenset((Literal(Q.negative_infinite, True), Literal(Q.positive, True))), + frozenset((Literal(Q.negative_infinite, True), Literal(Q.positive_infinite, True))), + frozenset((Literal(Q.negative_infinite, True), Literal(Q.zero, True))), + frozenset((Literal(Q.odd, True), Literal(Q.rational, False))), + frozenset((Literal(Q.positive, False), Literal(Q.prime, True))), + frozenset((Literal(Q.positive, True), Literal(Q.positive_infinite, True))), + frozenset((Literal(Q.positive, True), Literal(Q.zero, True))), + frozenset((Literal(Q.positive_infinite, True), Literal(Q.zero, True))) + } + +@cacheit +def get_known_facts_dict(): + """ + Logical relations between unary predicates as dictionary. + + Each key is a predicate, and item is two groups of predicates. + First group contains the predicates which are implied by the key, and + second group contains the predicates which are rejected by the key. + + """ + return { + Q.algebraic: (set([Q.algebraic, Q.commutative, Q.complex, Q.finite]), + set([Q.infinite, Q.negative_infinite, Q.positive_infinite, + Q.transcendental])), + Q.antihermitian: (set([Q.antihermitian]), set([])), + Q.commutative: (set([Q.commutative]), set([])), + Q.complex: (set([Q.commutative, Q.complex, Q.finite]), + set([Q.infinite, Q.negative_infinite, Q.positive_infinite])), + Q.complex_elements: (set([Q.complex_elements]), set([])), + Q.composite: (set([Q.algebraic, Q.commutative, Q.complex, Q.composite, + Q.extended_nonnegative, Q.extended_nonzero, + Q.extended_positive, Q.extended_real, Q.finite, Q.hermitian, + Q.integer, Q.nonnegative, Q.nonzero, Q.positive, Q.rational, + Q.real]), set([Q.extended_negative, Q.extended_nonpositive, + Q.imaginary, Q.infinite, Q.irrational, Q.negative, + Q.negative_infinite, Q.nonpositive, Q.positive_infinite, + Q.prime, Q.transcendental, Q.zero])), + Q.diagonal: (set([Q.diagonal, Q.lower_triangular, Q.normal, Q.square, + Q.symmetric, Q.triangular, Q.upper_triangular]), set([])), + Q.even: (set([Q.algebraic, Q.commutative, Q.complex, Q.even, + Q.extended_real, Q.finite, Q.hermitian, Q.integer, Q.rational, + Q.real]), set([Q.imaginary, Q.infinite, Q.irrational, + Q.negative_infinite, Q.odd, Q.positive_infinite, + Q.transcendental])), + Q.extended_negative: (set([Q.commutative, Q.extended_negative, + Q.extended_nonpositive, Q.extended_nonzero, Q.extended_real]), + set([Q.composite, Q.extended_nonnegative, Q.extended_positive, + Q.imaginary, Q.nonnegative, Q.positive, Q.positive_infinite, + Q.prime, Q.zero])), + Q.extended_nonnegative: (set([Q.commutative, Q.extended_nonnegative, + Q.extended_real]), set([Q.extended_negative, Q.imaginary, + Q.negative, Q.negative_infinite])), + Q.extended_nonpositive: (set([Q.commutative, Q.extended_nonpositive, + Q.extended_real]), set([Q.composite, Q.extended_positive, + Q.imaginary, Q.positive, Q.positive_infinite, Q.prime])), + Q.extended_nonzero: (set([Q.commutative, Q.extended_nonzero, + Q.extended_real]), set([Q.imaginary, Q.zero])), + Q.extended_positive: (set([Q.commutative, Q.extended_nonnegative, + Q.extended_nonzero, Q.extended_positive, Q.extended_real]), + set([Q.extended_negative, Q.extended_nonpositive, Q.imaginary, + Q.negative, Q.negative_infinite, Q.nonpositive, Q.zero])), + Q.extended_real: (set([Q.commutative, Q.extended_real]), + set([Q.imaginary])), + Q.finite: (set([Q.commutative, Q.finite]), set([Q.infinite, + Q.negative_infinite, Q.positive_infinite])), + Q.fullrank: (set([Q.fullrank]), set([])), + Q.hermitian: (set([Q.hermitian]), set([])), + Q.imaginary: (set([Q.antihermitian, Q.commutative, Q.complex, + Q.finite, Q.imaginary]), set([Q.composite, Q.even, + Q.extended_negative, Q.extended_nonnegative, + Q.extended_nonpositive, Q.extended_nonzero, + Q.extended_positive, Q.extended_real, Q.infinite, Q.integer, + Q.irrational, Q.negative, Q.negative_infinite, Q.nonnegative, + Q.nonpositive, Q.nonzero, Q.odd, Q.positive, + Q.positive_infinite, Q.prime, Q.rational, Q.real, Q.zero])), + Q.infinite: (set([Q.commutative, Q.infinite]), set([Q.algebraic, + Q.complex, Q.composite, Q.even, Q.finite, Q.imaginary, + Q.integer, Q.irrational, Q.negative, Q.nonnegative, + Q.nonpositive, Q.nonzero, Q.odd, Q.positive, Q.prime, + Q.rational, Q.real, Q.transcendental, Q.zero])), + Q.integer: (set([Q.algebraic, Q.commutative, Q.complex, + Q.extended_real, Q.finite, Q.hermitian, Q.integer, Q.rational, + Q.real]), set([Q.imaginary, Q.infinite, Q.irrational, + Q.negative_infinite, Q.positive_infinite, Q.transcendental])), + Q.integer_elements: (set([Q.complex_elements, Q.integer_elements, + Q.real_elements]), set([])), + Q.invertible: (set([Q.fullrank, Q.invertible, Q.square]), + set([Q.singular])), + Q.irrational: (set([Q.commutative, Q.complex, Q.extended_nonzero, + Q.extended_real, Q.finite, Q.hermitian, Q.irrational, + Q.nonzero, Q.real]), set([Q.composite, Q.even, Q.imaginary, + Q.infinite, Q.integer, Q.negative_infinite, Q.odd, + Q.positive_infinite, Q.prime, Q.rational, Q.zero])), + Q.is_true: (set([Q.is_true]), set([])), + Q.lower_triangular: (set([Q.lower_triangular, Q.triangular]), set([])), + Q.negative: (set([Q.commutative, Q.complex, Q.extended_negative, + Q.extended_nonpositive, Q.extended_nonzero, Q.extended_real, + Q.finite, Q.hermitian, Q.negative, Q.nonpositive, Q.nonzero, + Q.real]), set([Q.composite, Q.extended_nonnegative, + Q.extended_positive, Q.imaginary, Q.infinite, + Q.negative_infinite, Q.nonnegative, Q.positive, + Q.positive_infinite, Q.prime, Q.zero])), + Q.negative_infinite: (set([Q.commutative, Q.extended_negative, + Q.extended_nonpositive, Q.extended_nonzero, Q.extended_real, + Q.infinite, Q.negative_infinite]), set([Q.algebraic, + Q.complex, Q.composite, Q.even, Q.extended_nonnegative, + Q.extended_positive, Q.finite, Q.imaginary, Q.integer, + Q.irrational, Q.negative, Q.nonnegative, Q.nonpositive, + Q.nonzero, Q.odd, Q.positive, Q.positive_infinite, Q.prime, + Q.rational, Q.real, Q.transcendental, Q.zero])), + Q.noninteger: (set([Q.noninteger]), set([])), + Q.nonnegative: (set([Q.commutative, Q.complex, Q.extended_nonnegative, + Q.extended_real, Q.finite, Q.hermitian, Q.nonnegative, + Q.real]), set([Q.extended_negative, Q.imaginary, Q.infinite, + Q.negative, Q.negative_infinite, Q.positive_infinite])), + Q.nonpositive: (set([Q.commutative, Q.complex, Q.extended_nonpositive, + Q.extended_real, Q.finite, Q.hermitian, Q.nonpositive, + Q.real]), set([Q.composite, Q.extended_positive, Q.imaginary, + Q.infinite, Q.negative_infinite, Q.positive, + Q.positive_infinite, Q.prime])), + Q.nonzero: (set([Q.commutative, Q.complex, Q.extended_nonzero, + Q.extended_real, Q.finite, Q.hermitian, Q.nonzero, Q.real]), + set([Q.imaginary, Q.infinite, Q.negative_infinite, + Q.positive_infinite, Q.zero])), + Q.normal: (set([Q.normal, Q.square]), set([])), + Q.odd: (set([Q.algebraic, Q.commutative, Q.complex, + Q.extended_nonzero, Q.extended_real, Q.finite, Q.hermitian, + Q.integer, Q.nonzero, Q.odd, Q.rational, Q.real]), + set([Q.even, Q.imaginary, Q.infinite, Q.irrational, + Q.negative_infinite, Q.positive_infinite, Q.transcendental, + Q.zero])), + Q.orthogonal: (set([Q.fullrank, Q.invertible, Q.normal, Q.orthogonal, + Q.positive_definite, Q.square, Q.unitary]), set([Q.singular])), + Q.positive: (set([Q.commutative, Q.complex, Q.extended_nonnegative, + Q.extended_nonzero, Q.extended_positive, Q.extended_real, + Q.finite, Q.hermitian, Q.nonnegative, Q.nonzero, Q.positive, + Q.real]), set([Q.extended_negative, Q.extended_nonpositive, + Q.imaginary, Q.infinite, Q.negative, Q.negative_infinite, + Q.nonpositive, Q.positive_infinite, Q.zero])), + Q.positive_definite: (set([Q.fullrank, Q.invertible, + Q.positive_definite, Q.square]), set([Q.singular])), + Q.positive_infinite: (set([Q.commutative, Q.extended_nonnegative, + Q.extended_nonzero, Q.extended_positive, Q.extended_real, + Q.infinite, Q.positive_infinite]), set([Q.algebraic, + Q.complex, Q.composite, Q.even, Q.extended_negative, + Q.extended_nonpositive, Q.finite, Q.imaginary, Q.integer, + Q.irrational, Q.negative, Q.negative_infinite, Q.nonnegative, + Q.nonpositive, Q.nonzero, Q.odd, Q.positive, Q.prime, + Q.rational, Q.real, Q.transcendental, Q.zero])), + Q.prime: (set([Q.algebraic, Q.commutative, Q.complex, + Q.extended_nonnegative, Q.extended_nonzero, + Q.extended_positive, Q.extended_real, Q.finite, Q.hermitian, + Q.integer, Q.nonnegative, Q.nonzero, Q.positive, Q.prime, + Q.rational, Q.real]), set([Q.composite, Q.extended_negative, + Q.extended_nonpositive, Q.imaginary, Q.infinite, Q.irrational, + Q.negative, Q.negative_infinite, Q.nonpositive, + Q.positive_infinite, Q.transcendental, Q.zero])), + Q.rational: (set([Q.algebraic, Q.commutative, Q.complex, + Q.extended_real, Q.finite, Q.hermitian, Q.rational, Q.real]), + set([Q.imaginary, Q.infinite, Q.irrational, + Q.negative_infinite, Q.positive_infinite, Q.transcendental])), + Q.real: (set([Q.commutative, Q.complex, Q.extended_real, Q.finite, + Q.hermitian, Q.real]), set([Q.imaginary, Q.infinite, + Q.negative_infinite, Q.positive_infinite])), + Q.real_elements: (set([Q.complex_elements, Q.real_elements]), set([])), + Q.singular: (set([Q.singular]), set([Q.invertible, Q.orthogonal, + Q.positive_definite, Q.unitary])), + Q.square: (set([Q.square]), set([])), + Q.symmetric: (set([Q.square, Q.symmetric]), set([])), + Q.transcendental: (set([Q.commutative, Q.complex, Q.finite, + Q.transcendental]), set([Q.algebraic, Q.composite, Q.even, + Q.infinite, Q.integer, Q.negative_infinite, Q.odd, + Q.positive_infinite, Q.prime, Q.rational, Q.zero])), + Q.triangular: (set([Q.triangular]), set([])), + Q.unit_triangular: (set([Q.triangular, Q.unit_triangular]), set([])), + Q.unitary: (set([Q.fullrank, Q.invertible, Q.normal, Q.square, + Q.unitary]), set([Q.singular])), + Q.upper_triangular: (set([Q.triangular, Q.upper_triangular]), set([])), + Q.zero: (set([Q.algebraic, Q.commutative, Q.complex, Q.even, + Q.extended_nonnegative, Q.extended_nonpositive, + Q.extended_real, Q.finite, Q.hermitian, Q.integer, + Q.nonnegative, Q.nonpositive, Q.rational, Q.real, Q.zero]), + set([Q.composite, Q.extended_negative, Q.extended_nonzero, + Q.extended_positive, Q.imaginary, Q.infinite, Q.irrational, + Q.negative, Q.negative_infinite, Q.nonzero, Q.odd, Q.positive, + Q.positive_infinite, Q.prime, Q.transcendental])), + } diff --git a/.venv/lib/python3.13/site-packages/sympy/assumptions/assume.py b/.venv/lib/python3.13/site-packages/sympy/assumptions/assume.py new file mode 100644 index 0000000000000000000000000000000000000000..743195a865a1d39389d471b95728ca79834ed019 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/assumptions/assume.py @@ -0,0 +1,485 @@ +"""A module which implements predicates and assumption context.""" + +from contextlib import contextmanager +import inspect +from sympy.core.symbol import Str +from sympy.core.sympify import _sympify +from sympy.logic.boolalg import Boolean, false, true +from sympy.multipledispatch.dispatcher import Dispatcher, str_signature +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import is_sequence +from sympy.utilities.source import get_class + + +class AssumptionsContext(set): + """ + Set containing default assumptions which are applied to the ``ask()`` + function. + + Explanation + =========== + + This is used to represent global assumptions, but you can also use this + class to create your own local assumptions contexts. It is basically a thin + wrapper to Python's set, so see its documentation for advanced usage. + + Examples + ======== + + The default assumption context is ``global_assumptions``, which is initially empty: + + >>> from sympy import ask, Q + >>> from sympy.assumptions import global_assumptions + >>> global_assumptions + AssumptionsContext() + + You can add default assumptions: + + >>> from sympy.abc import x + >>> global_assumptions.add(Q.real(x)) + >>> global_assumptions + AssumptionsContext({Q.real(x)}) + >>> ask(Q.real(x)) + True + + And remove them: + + >>> global_assumptions.remove(Q.real(x)) + >>> print(ask(Q.real(x))) + None + + The ``clear()`` method removes every assumption: + + >>> global_assumptions.add(Q.positive(x)) + >>> global_assumptions + AssumptionsContext({Q.positive(x)}) + >>> global_assumptions.clear() + >>> global_assumptions + AssumptionsContext() + + See Also + ======== + + assuming + + """ + + def add(self, *assumptions): + """Add assumptions.""" + for a in assumptions: + super().add(a) + + def _sympystr(self, printer): + if not self: + return "%s()" % self.__class__.__name__ + return "{}({})".format(self.__class__.__name__, printer._print_set(self)) + +global_assumptions = AssumptionsContext() + + +class AppliedPredicate(Boolean): + """ + The class of expressions resulting from applying ``Predicate`` to + the arguments. ``AppliedPredicate`` merely wraps its argument and + remain unevaluated. To evaluate it, use the ``ask()`` function. + + Examples + ======== + + >>> from sympy import Q, ask + >>> Q.integer(1) + Q.integer(1) + + The ``function`` attribute returns the predicate, and the ``arguments`` + attribute returns the tuple of arguments. + + >>> type(Q.integer(1)) + + >>> Q.integer(1).function + Q.integer + >>> Q.integer(1).arguments + (1,) + + Applied predicates can be evaluated to a boolean value with ``ask``: + + >>> ask(Q.integer(1)) + True + + """ + __slots__ = () + + def __new__(cls, predicate, *args): + if not isinstance(predicate, Predicate): + raise TypeError("%s is not a Predicate." % predicate) + args = map(_sympify, args) + return super().__new__(cls, predicate, *args) + + @property + def arg(self): + """ + Return the expression used by this assumption. + + Examples + ======== + + >>> from sympy import Q, Symbol + >>> x = Symbol('x') + >>> a = Q.integer(x + 1) + >>> a.arg + x + 1 + + """ + # Will be deprecated + args = self._args + if len(args) == 2: + # backwards compatibility + return args[1] + raise TypeError("'arg' property is allowed only for unary predicates.") + + @property + def function(self): + """ + Return the predicate. + """ + # Will be changed to self.args[0] after args overriding is removed + return self._args[0] + + @property + def arguments(self): + """ + Return the arguments which are applied to the predicate. + """ + # Will be changed to self.args[1:] after args overriding is removed + return self._args[1:] + + def _eval_ask(self, assumptions): + return self.function.eval(self.arguments, assumptions) + + @property + def binary_symbols(self): + from .ask import Q + if self.function == Q.is_true: + i = self.arguments[0] + if i.is_Boolean or i.is_Symbol: + return i.binary_symbols + if self.function in (Q.eq, Q.ne): + if true in self.arguments or false in self.arguments: + if self.arguments[0].is_Symbol: + return {self.arguments[0]} + elif self.arguments[1].is_Symbol: + return {self.arguments[1]} + return set() + + +class PredicateMeta(type): + def __new__(cls, clsname, bases, dct): + # If handler is not defined, assign empty dispatcher. + if "handler" not in dct: + name = f"Ask{clsname.capitalize()}Handler" + handler = Dispatcher(name, doc="Handler for key %s" % name) + dct["handler"] = handler + + dct["_orig_doc"] = dct.get("__doc__", "") + + return super().__new__(cls, clsname, bases, dct) + + @property + def __doc__(cls): + handler = cls.handler + doc = cls._orig_doc + if cls is not Predicate and handler is not None: + doc += "Handler\n" + doc += " =======\n\n" + + # Append the handler's doc without breaking sphinx documentation. + docs = [" Multiply dispatched method: %s" % handler.name] + if handler.doc: + for line in handler.doc.splitlines(): + if not line: + continue + docs.append(" %s" % line) + other = [] + for sig in handler.ordering[::-1]: + func = handler.funcs[sig] + if func.__doc__: + s = ' Inputs: <%s>' % str_signature(sig) + lines = [] + for line in func.__doc__.splitlines(): + lines.append(" %s" % line) + s += "\n".join(lines) + docs.append(s) + else: + other.append(str_signature(sig)) + if other: + othersig = " Other signatures:" + for line in other: + othersig += "\n * %s" % line + docs.append(othersig) + + doc += '\n\n'.join(docs) + + return doc + + +class Predicate(Boolean, metaclass=PredicateMeta): + """ + Base class for mathematical predicates. It also serves as a + constructor for undefined predicate objects. + + Explanation + =========== + + Predicate is a function that returns a boolean value [1]. + + Predicate function is object, and it is instance of predicate class. + When a predicate is applied to arguments, ``AppliedPredicate`` + instance is returned. This merely wraps the argument and remain + unevaluated. To obtain the truth value of applied predicate, use the + function ``ask``. + + Evaluation of predicate is done by multiple dispatching. You can + register new handler to the predicate to support new types. + + Every predicate in SymPy can be accessed via the property of ``Q``. + For example, ``Q.even`` returns the predicate which checks if the + argument is even number. + + To define a predicate which can be evaluated, you must subclass this + class, make an instance of it, and register it to ``Q``. After then, + dispatch the handler by argument types. + + If you directly construct predicate using this class, you will get + ``UndefinedPredicate`` which cannot be dispatched. This is useful + when you are building boolean expressions which do not need to be + evaluated. + + Examples + ======== + + Applying and evaluating to boolean value: + + >>> from sympy import Q, ask + >>> ask(Q.prime(7)) + True + + You can define a new predicate by subclassing and dispatching. Here, + we define a predicate for sexy primes [2] as an example. + + >>> from sympy import Predicate, Integer + >>> class SexyPrimePredicate(Predicate): + ... name = "sexyprime" + >>> Q.sexyprime = SexyPrimePredicate() + >>> @Q.sexyprime.register(Integer, Integer) + ... def _(int1, int2, assumptions): + ... args = sorted([int1, int2]) + ... if not all(ask(Q.prime(a), assumptions) for a in args): + ... return False + ... return args[1] - args[0] == 6 + >>> ask(Q.sexyprime(5, 11)) + True + + Direct constructing returns ``UndefinedPredicate``, which can be + applied but cannot be dispatched. + + >>> from sympy import Predicate, Integer + >>> Q.P = Predicate("P") + >>> type(Q.P) + + >>> Q.P(1) + Q.P(1) + >>> Q.P.register(Integer)(lambda expr, assump: True) + Traceback (most recent call last): + ... + TypeError: cannot be dispatched. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Predicate_%28mathematical_logic%29 + .. [2] https://en.wikipedia.org/wiki/Sexy_prime + + """ + + is_Atom = True + + def __new__(cls, *args, **kwargs): + if cls is Predicate: + return UndefinedPredicate(*args, **kwargs) + obj = super().__new__(cls, *args) + return obj + + @property + def name(self): + # May be overridden + return type(self).__name__ + + @classmethod + def register(cls, *types, **kwargs): + """ + Register the signature to the handler. + """ + if cls.handler is None: + raise TypeError("%s cannot be dispatched." % type(cls)) + return cls.handler.register(*types, **kwargs) + + @classmethod + def register_many(cls, *types, **kwargs): + """ + Register multiple signatures to same handler. + """ + def _(func): + for t in types: + if not is_sequence(t): + t = (t,) # for convenience, allow passing `type` to mean `(type,)` + cls.register(*t, **kwargs)(func) + return _ + + def __call__(self, *args): + return AppliedPredicate(self, *args) + + def eval(self, args, assumptions=True): + """ + Evaluate ``self(*args)`` under the given assumptions. + + This uses only direct resolution methods, not logical inference. + """ + result = None + try: + result = self.handler(*args, assumptions=assumptions) + except NotImplementedError: + pass + return result + + def _eval_refine(self, assumptions): + # When Predicate is no longer Boolean, delete this method + return self + + +class UndefinedPredicate(Predicate): + """ + Predicate without handler. + + Explanation + =========== + + This predicate is generated by using ``Predicate`` directly for + construction. It does not have a handler, and evaluating this with + arguments is done by SAT solver. + + Examples + ======== + + >>> from sympy import Predicate, Q + >>> Q.P = Predicate('P') + >>> Q.P.func + + >>> Q.P.name + Str('P') + + """ + + handler = None + + def __new__(cls, name, handlers=None): + # "handlers" parameter supports old design + if not isinstance(name, Str): + name = Str(name) + obj = super(Boolean, cls).__new__(cls, name) + obj.handlers = handlers or [] + return obj + + @property + def name(self): + return self.args[0] + + def _hashable_content(self): + return (self.name,) + + def __getnewargs__(self): + return (self.name,) + + def __call__(self, expr): + return AppliedPredicate(self, expr) + + def add_handler(self, handler): + sympy_deprecation_warning( + """ + The AskHandler system is deprecated. Predicate.add_handler() + should be replaced with the multipledispatch handler of Predicate. + """, + deprecated_since_version="1.8", + active_deprecations_target='deprecated-askhandler', + ) + self.handlers.append(handler) + + def remove_handler(self, handler): + sympy_deprecation_warning( + """ + The AskHandler system is deprecated. Predicate.remove_handler() + should be replaced with the multipledispatch handler of Predicate. + """, + deprecated_since_version="1.8", + active_deprecations_target='deprecated-askhandler', + ) + self.handlers.remove(handler) + + def eval(self, args, assumptions=True): + # Support for deprecated design + # When old design is removed, this will always return None + sympy_deprecation_warning( + """ + The AskHandler system is deprecated. Evaluating UndefinedPredicate + objects should be replaced with the multipledispatch handler of + Predicate. + """, + deprecated_since_version="1.8", + active_deprecations_target='deprecated-askhandler', + stacklevel=5, + ) + expr, = args + res, _res = None, None + mro = inspect.getmro(type(expr)) + for handler in self.handlers: + cls = get_class(handler) + for subclass in mro: + eval_ = getattr(cls, subclass.__name__, None) + if eval_ is None: + continue + res = eval_(expr, assumptions) + # Do not stop if value returned is None + # Try to check for higher classes + if res is None: + continue + if _res is None: + _res = res + else: + # only check consistency if both resolutors have concluded + if _res != res: + raise ValueError('incompatible resolutors') + break + return res + + +@contextmanager +def assuming(*assumptions): + """ + Context manager for assumptions. + + Examples + ======== + + >>> from sympy import assuming, Q, ask + >>> from sympy.abc import x, y + >>> print(ask(Q.integer(x + y))) + None + >>> with assuming(Q.integer(x), Q.integer(y)): + ... print(ask(Q.integer(x + y))) + True + """ + old_global_assumptions = global_assumptions.copy() + global_assumptions.update(assumptions) + try: + yield + finally: + global_assumptions.clear() + global_assumptions.update(old_global_assumptions) diff --git a/.venv/lib/python3.13/site-packages/sympy/assumptions/cnf.py b/.venv/lib/python3.13/site-packages/sympy/assumptions/cnf.py new file mode 100644 index 0000000000000000000000000000000000000000..a95d27bed6eeb64c42f4edd9d49bd8e5753069e5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/assumptions/cnf.py @@ -0,0 +1,445 @@ +""" +The classes used here are for the internal use of assumptions system +only and should not be used anywhere else as these do not possess the +signatures common to SymPy objects. For general use of logic constructs +please refer to sympy.logic classes And, Or, Not, etc. +""" +from itertools import combinations, product, zip_longest +from sympy.assumptions.assume import AppliedPredicate, Predicate +from sympy.core.relational import Eq, Ne, Gt, Lt, Ge, Le +from sympy.core.singleton import S +from sympy.logic.boolalg import Or, And, Not, Xnor +from sympy.logic.boolalg import (Equivalent, ITE, Implies, Nand, Nor, Xor) + + +class Literal: + """ + The smallest element of a CNF object. + + Parameters + ========== + + lit : Boolean expression + + is_Not : bool + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.cnf import Literal + >>> from sympy.abc import x + >>> Literal(Q.even(x)) + Literal(Q.even(x), False) + >>> Literal(~Q.even(x)) + Literal(Q.even(x), True) + """ + + def __new__(cls, lit, is_Not=False): + if isinstance(lit, Not): + lit = lit.args[0] + is_Not = True + elif isinstance(lit, (AND, OR, Literal)): + return ~lit if is_Not else lit + obj = super().__new__(cls) + obj.lit = lit + obj.is_Not = is_Not + return obj + + @property + def arg(self): + return self.lit + + def rcall(self, expr): + if callable(self.lit): + lit = self.lit(expr) + else: + lit = self.lit.apply(expr) + return type(self)(lit, self.is_Not) + + def __invert__(self): + is_Not = not self.is_Not + return Literal(self.lit, is_Not) + + def __str__(self): + return '{}({}, {})'.format(type(self).__name__, self.lit, self.is_Not) + + __repr__ = __str__ + + def __eq__(self, other): + return self.arg == other.arg and self.is_Not == other.is_Not + + def __hash__(self): + h = hash((type(self).__name__, self.arg, self.is_Not)) + return h + + +class OR: + """ + A low-level implementation for Or + """ + def __init__(self, *args): + self._args = args + + @property + def args(self): + return sorted(self._args, key=str) + + def rcall(self, expr): + return type(self)(*[arg.rcall(expr) + for arg in self._args + ]) + + def __invert__(self): + return AND(*[~arg for arg in self._args]) + + def __hash__(self): + return hash((type(self).__name__,) + tuple(self.args)) + + def __eq__(self, other): + return self.args == other.args + + def __str__(self): + s = '(' + ' | '.join([str(arg) for arg in self.args]) + ')' + return s + + __repr__ = __str__ + + +class AND: + """ + A low-level implementation for And + """ + def __init__(self, *args): + self._args = args + + def __invert__(self): + return OR(*[~arg for arg in self._args]) + + @property + def args(self): + return sorted(self._args, key=str) + + def rcall(self, expr): + return type(self)(*[arg.rcall(expr) + for arg in self._args + ]) + + def __hash__(self): + return hash((type(self).__name__,) + tuple(self.args)) + + def __eq__(self, other): + return self.args == other.args + + def __str__(self): + s = '('+' & '.join([str(arg) for arg in self.args])+')' + return s + + __repr__ = __str__ + + +def to_NNF(expr, composite_map=None): + """ + Generates the Negation Normal Form of any boolean expression in terms + of AND, OR, and Literal objects. + + Examples + ======== + + >>> from sympy import Q, Eq + >>> from sympy.assumptions.cnf import to_NNF + >>> from sympy.abc import x, y + >>> expr = Q.even(x) & ~Q.positive(x) + >>> to_NNF(expr) + (Literal(Q.even(x), False) & Literal(Q.positive(x), True)) + + Supported boolean objects are converted to corresponding predicates. + + >>> to_NNF(Eq(x, y)) + Literal(Q.eq(x, y), False) + + If ``composite_map`` argument is given, ``to_NNF`` decomposes the + specified predicate into a combination of primitive predicates. + + >>> cmap = {Q.nonpositive: Q.negative | Q.zero} + >>> to_NNF(Q.nonpositive, cmap) + (Literal(Q.negative, False) | Literal(Q.zero, False)) + >>> to_NNF(Q.nonpositive(x), cmap) + (Literal(Q.negative(x), False) | Literal(Q.zero(x), False)) + """ + from sympy.assumptions.ask import Q + + if composite_map is None: + composite_map = {} + + + binrelpreds = {Eq: Q.eq, Ne: Q.ne, Gt: Q.gt, Lt: Q.lt, Ge: Q.ge, Le: Q.le} + if type(expr) in binrelpreds: + pred = binrelpreds[type(expr)] + expr = pred(*expr.args) + + if isinstance(expr, Not): + arg = expr.args[0] + tmp = to_NNF(arg, composite_map) # Strategy: negate the NNF of expr + return ~tmp + + if isinstance(expr, Or): + return OR(*[to_NNF(x, composite_map) for x in Or.make_args(expr)]) + + if isinstance(expr, And): + return AND(*[to_NNF(x, composite_map) for x in And.make_args(expr)]) + + if isinstance(expr, Nand): + tmp = AND(*[to_NNF(x, composite_map) for x in expr.args]) + return ~tmp + + if isinstance(expr, Nor): + tmp = OR(*[to_NNF(x, composite_map) for x in expr.args]) + return ~tmp + + if isinstance(expr, Xor): + cnfs = [] + for i in range(0, len(expr.args) + 1, 2): + for neg in combinations(expr.args, i): + clause = [~to_NNF(s, composite_map) if s in neg else to_NNF(s, composite_map) + for s in expr.args] + cnfs.append(OR(*clause)) + return AND(*cnfs) + + if isinstance(expr, Xnor): + cnfs = [] + for i in range(0, len(expr.args) + 1, 2): + for neg in combinations(expr.args, i): + clause = [~to_NNF(s, composite_map) if s in neg else to_NNF(s, composite_map) + for s in expr.args] + cnfs.append(OR(*clause)) + return ~AND(*cnfs) + + if isinstance(expr, Implies): + L, R = to_NNF(expr.args[0], composite_map), to_NNF(expr.args[1], composite_map) + return OR(~L, R) + + if isinstance(expr, Equivalent): + cnfs = [] + for a, b in zip_longest(expr.args, expr.args[1:], fillvalue=expr.args[0]): + a = to_NNF(a, composite_map) + b = to_NNF(b, composite_map) + cnfs.append(OR(~a, b)) + return AND(*cnfs) + + if isinstance(expr, ITE): + L = to_NNF(expr.args[0], composite_map) + M = to_NNF(expr.args[1], composite_map) + R = to_NNF(expr.args[2], composite_map) + return AND(OR(~L, M), OR(L, R)) + + if isinstance(expr, AppliedPredicate): + pred, args = expr.function, expr.arguments + newpred = composite_map.get(pred, None) + if newpred is not None: + return to_NNF(newpred.rcall(*args), composite_map) + + if isinstance(expr, Predicate): + newpred = composite_map.get(expr, None) + if newpred is not None: + return to_NNF(newpred, composite_map) + + return Literal(expr) + + +def distribute_AND_over_OR(expr): + """ + Distributes AND over OR in the NNF expression. + Returns the result( Conjunctive Normal Form of expression) + as a CNF object. + """ + if not isinstance(expr, (AND, OR)): + tmp = set() + tmp.add(frozenset((expr,))) + return CNF(tmp) + + if isinstance(expr, OR): + return CNF.all_or(*[distribute_AND_over_OR(arg) + for arg in expr._args]) + + if isinstance(expr, AND): + return CNF.all_and(*[distribute_AND_over_OR(arg) + for arg in expr._args]) + + +class CNF: + """ + Class to represent CNF of a Boolean expression. + Consists of set of clauses, which themselves are stored as + frozenset of Literal objects. + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.cnf import CNF + >>> from sympy.abc import x + >>> cnf = CNF.from_prop(Q.real(x) & ~Q.zero(x)) + >>> cnf.clauses + {frozenset({Literal(Q.zero(x), True)}), + frozenset({Literal(Q.negative(x), False), + Literal(Q.positive(x), False), Literal(Q.zero(x), False)})} + """ + def __init__(self, clauses=None): + if not clauses: + clauses = set() + self.clauses = clauses + + def add(self, prop): + clauses = CNF.to_CNF(prop).clauses + self.add_clauses(clauses) + + def __str__(self): + s = ' & '.join( + ['(' + ' | '.join([str(lit) for lit in clause]) +')' + for clause in self.clauses] + ) + return s + + def extend(self, props): + for p in props: + self.add(p) + return self + + def copy(self): + return CNF(set(self.clauses)) + + def add_clauses(self, clauses): + self.clauses |= clauses + + @classmethod + def from_prop(cls, prop): + res = cls() + res.add(prop) + return res + + def __iand__(self, other): + self.add_clauses(other.clauses) + return self + + def all_predicates(self): + predicates = set() + for c in self.clauses: + predicates |= {arg.lit for arg in c} + return predicates + + def _or(self, cnf): + clauses = set() + for a, b in product(self.clauses, cnf.clauses): + tmp = set(a) + tmp.update(b) + clauses.add(frozenset(tmp)) + return CNF(clauses) + + def _and(self, cnf): + clauses = self.clauses.union(cnf.clauses) + return CNF(clauses) + + def _not(self): + clss = list(self.clauses) + ll = {frozenset((~x,)) for x in clss[-1]} + ll = CNF(ll) + + for rest in clss[:-1]: + p = {frozenset((~x,)) for x in rest} + ll = ll._or(CNF(p)) + return ll + + def rcall(self, expr): + clause_list = [] + for clause in self.clauses: + lits = [arg.rcall(expr) for arg in clause] + clause_list.append(OR(*lits)) + expr = AND(*clause_list) + return distribute_AND_over_OR(expr) + + @classmethod + def all_or(cls, *cnfs): + b = cnfs[0].copy() + for rest in cnfs[1:]: + b = b._or(rest) + return b + + @classmethod + def all_and(cls, *cnfs): + b = cnfs[0].copy() + for rest in cnfs[1:]: + b = b._and(rest) + return b + + @classmethod + def to_CNF(cls, expr): + from sympy.assumptions.facts import get_composite_predicates + expr = to_NNF(expr, get_composite_predicates()) + expr = distribute_AND_over_OR(expr) + return expr + + @classmethod + def CNF_to_cnf(cls, cnf): + """ + Converts CNF object to SymPy's boolean expression + retaining the form of expression. + """ + def remove_literal(arg): + return Not(arg.lit) if arg.is_Not else arg.lit + + return And(*(Or(*(remove_literal(arg) for arg in clause)) for clause in cnf.clauses)) + + +class EncodedCNF: + """ + Class for encoding the CNF expression. + """ + def __init__(self, data=None, encoding=None): + if not data and not encoding: + data = [] + encoding = {} + self.data = data + self.encoding = encoding + self._symbols = list(encoding.keys()) + + def from_cnf(self, cnf): + self._symbols = list(cnf.all_predicates()) + n = len(self._symbols) + self.encoding = dict(zip(self._symbols, range(1, n + 1))) + self.data = [self.encode(clause) for clause in cnf.clauses] + + @property + def symbols(self): + return self._symbols + + @property + def variables(self): + return range(1, len(self._symbols) + 1) + + def copy(self): + new_data = [set(clause) for clause in self.data] + return EncodedCNF(new_data, dict(self.encoding)) + + def add_prop(self, prop): + cnf = CNF.from_prop(prop) + self.add_from_cnf(cnf) + + def add_from_cnf(self, cnf): + clauses = [self.encode(clause) for clause in cnf.clauses] + self.data += clauses + + def encode_arg(self, arg): + literal = arg.lit + value = self.encoding.get(literal, None) + if value is None: + n = len(self._symbols) + self._symbols.append(literal) + value = self.encoding[literal] = n + 1 + if arg.is_Not: + return -value + else: + return value + + def encode(self, clause): + return {self.encode_arg(arg) if not arg.lit == S.false else 0 for arg in clause} diff --git a/.venv/lib/python3.13/site-packages/sympy/assumptions/facts.py b/.venv/lib/python3.13/site-packages/sympy/assumptions/facts.py new file mode 100644 index 0000000000000000000000000000000000000000..2ff268677cf74e252ac6c3bc3eecbea08b9414d0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/assumptions/facts.py @@ -0,0 +1,270 @@ +""" +Known facts in assumptions module. + +This module defines the facts between unary predicates in ``get_known_facts()``, +and supports functions to generate the contents in +``sympy.assumptions.ask_generated`` file. +""" + +from sympy.assumptions.ask import Q +from sympy.assumptions.assume import AppliedPredicate +from sympy.core.cache import cacheit +from sympy.core.symbol import Symbol +from sympy.logic.boolalg import (to_cnf, And, Not, Implies, Equivalent, + Exclusive,) +from sympy.logic.inference import satisfiable + + +@cacheit +def get_composite_predicates(): + # To reduce the complexity of sat solver, these predicates are + # transformed into the combination of primitive predicates. + return { + Q.real : Q.negative | Q.zero | Q.positive, + Q.integer : Q.even | Q.odd, + Q.nonpositive : Q.negative | Q.zero, + Q.nonzero : Q.negative | Q.positive, + Q.nonnegative : Q.zero | Q.positive, + Q.extended_real : Q.negative_infinite | Q.negative | Q.zero | Q.positive | Q.positive_infinite, + Q.extended_positive: Q.positive | Q.positive_infinite, + Q.extended_negative: Q.negative | Q.negative_infinite, + Q.extended_nonzero: Q.negative_infinite | Q.negative | Q.positive | Q.positive_infinite, + Q.extended_nonpositive: Q.negative_infinite | Q.negative | Q.zero, + Q.extended_nonnegative: Q.zero | Q.positive | Q.positive_infinite, + Q.complex : Q.algebraic | Q.transcendental + } + + +@cacheit +def get_known_facts(x=None): + """ + Facts between unary predicates. + + Parameters + ========== + + x : Symbol, optional + Placeholder symbol for unary facts. Default is ``Symbol('x')``. + + Returns + ======= + + fact : Known facts in conjugated normal form. + + """ + if x is None: + x = Symbol('x') + + fact = And( + get_number_facts(x), + get_matrix_facts(x) + ) + return fact + + +@cacheit +def get_number_facts(x = None): + """ + Facts between unary number predicates. + + Parameters + ========== + + x : Symbol, optional + Placeholder symbol for unary facts. Default is ``Symbol('x')``. + + Returns + ======= + + fact : Known facts in conjugated normal form. + + """ + if x is None: + x = Symbol('x') + + fact = And( + # primitive predicates for extended real exclude each other. + Exclusive(Q.negative_infinite(x), Q.negative(x), Q.zero(x), + Q.positive(x), Q.positive_infinite(x)), + + # build complex plane + Exclusive(Q.real(x), Q.imaginary(x)), + Implies(Q.real(x) | Q.imaginary(x), Q.complex(x)), + + # other subsets of complex + Exclusive(Q.transcendental(x), Q.algebraic(x)), + Equivalent(Q.real(x), Q.rational(x) | Q.irrational(x)), + Exclusive(Q.irrational(x), Q.rational(x)), + Implies(Q.rational(x), Q.algebraic(x)), + + # integers + Exclusive(Q.even(x), Q.odd(x)), + Implies(Q.integer(x), Q.rational(x)), + Implies(Q.zero(x), Q.even(x)), + Exclusive(Q.composite(x), Q.prime(x)), + Implies(Q.composite(x) | Q.prime(x), Q.integer(x) & Q.positive(x)), + Implies(Q.even(x) & Q.positive(x) & ~Q.prime(x), Q.composite(x)), + + # hermitian and antihermitian + Implies(Q.real(x), Q.hermitian(x)), + Implies(Q.imaginary(x), Q.antihermitian(x)), + Implies(Q.zero(x), Q.hermitian(x) | Q.antihermitian(x)), + + # define finity and infinity, and build extended real line + Exclusive(Q.infinite(x), Q.finite(x)), + Implies(Q.complex(x), Q.finite(x)), + Implies(Q.negative_infinite(x) | Q.positive_infinite(x), Q.infinite(x)), + + # commutativity + Implies(Q.finite(x) | Q.infinite(x), Q.commutative(x)), + ) + return fact + + +@cacheit +def get_matrix_facts(x = None): + """ + Facts between unary matrix predicates. + + Parameters + ========== + + x : Symbol, optional + Placeholder symbol for unary facts. Default is ``Symbol('x')``. + + Returns + ======= + + fact : Known facts in conjugated normal form. + + """ + if x is None: + x = Symbol('x') + + fact = And( + # matrices + Implies(Q.orthogonal(x), Q.positive_definite(x)), + Implies(Q.orthogonal(x), Q.unitary(x)), + Implies(Q.unitary(x) & Q.real_elements(x), Q.orthogonal(x)), + Implies(Q.unitary(x), Q.normal(x)), + Implies(Q.unitary(x), Q.invertible(x)), + Implies(Q.normal(x), Q.square(x)), + Implies(Q.diagonal(x), Q.normal(x)), + Implies(Q.positive_definite(x), Q.invertible(x)), + Implies(Q.diagonal(x), Q.upper_triangular(x)), + Implies(Q.diagonal(x), Q.lower_triangular(x)), + Implies(Q.lower_triangular(x), Q.triangular(x)), + Implies(Q.upper_triangular(x), Q.triangular(x)), + Implies(Q.triangular(x), Q.upper_triangular(x) | Q.lower_triangular(x)), + Implies(Q.upper_triangular(x) & Q.lower_triangular(x), Q.diagonal(x)), + Implies(Q.diagonal(x), Q.symmetric(x)), + Implies(Q.unit_triangular(x), Q.triangular(x)), + Implies(Q.invertible(x), Q.fullrank(x)), + Implies(Q.invertible(x), Q.square(x)), + Implies(Q.symmetric(x), Q.square(x)), + Implies(Q.fullrank(x) & Q.square(x), Q.invertible(x)), + Equivalent(Q.invertible(x), ~Q.singular(x)), + Implies(Q.integer_elements(x), Q.real_elements(x)), + Implies(Q.real_elements(x), Q.complex_elements(x)), + ) + return fact + + + +def generate_known_facts_dict(keys, fact): + """ + Computes and returns a dictionary which contains the relations between + unary predicates. + + Each key is a predicate, and item is two groups of predicates. + First group contains the predicates which are implied by the key, and + second group contains the predicates which are rejected by the key. + + All predicates in *keys* and *fact* must be unary and have same placeholder + symbol. + + Parameters + ========== + + keys : list of AppliedPredicate instances. + + fact : Fact between predicates in conjugated normal form. + + Examples + ======== + + >>> from sympy import Q, And, Implies + >>> from sympy.assumptions.facts import generate_known_facts_dict + >>> from sympy.abc import x + >>> keys = [Q.even(x), Q.odd(x), Q.zero(x)] + >>> fact = And(Implies(Q.even(x), ~Q.odd(x)), + ... Implies(Q.zero(x), Q.even(x))) + >>> generate_known_facts_dict(keys, fact) + {Q.even: ({Q.even}, {Q.odd}), + Q.odd: ({Q.odd}, {Q.even, Q.zero}), + Q.zero: ({Q.even, Q.zero}, {Q.odd})} + """ + fact_cnf = to_cnf(fact) + mapping = single_fact_lookup(keys, fact_cnf) + + ret = {} + for key, value in mapping.items(): + implied = set() + rejected = set() + for expr in value: + if isinstance(expr, AppliedPredicate): + implied.add(expr.function) + elif isinstance(expr, Not): + pred = expr.args[0] + rejected.add(pred.function) + ret[key.function] = (implied, rejected) + return ret + + +@cacheit +def get_known_facts_keys(): + """ + Return every unary predicates registered to ``Q``. + + This function is used to generate the keys for + ``generate_known_facts_dict``. + + """ + # exclude polyadic predicates + exclude = {Q.eq, Q.ne, Q.gt, Q.lt, Q.ge, Q.le} + + result = [] + for attr in Q.__class__.__dict__: + if attr.startswith('__'): + continue + pred = getattr(Q, attr) + if pred in exclude: + continue + result.append(pred) + return result + + +def single_fact_lookup(known_facts_keys, known_facts_cnf): + # Return the dictionary for quick lookup of single fact + mapping = {} + for key in known_facts_keys: + mapping[key] = {key} + for other_key in known_facts_keys: + if other_key != key: + if ask_full_inference(other_key, key, known_facts_cnf): + mapping[key].add(other_key) + if ask_full_inference(~other_key, key, known_facts_cnf): + mapping[key].add(~other_key) + return mapping + + +def ask_full_inference(proposition, assumptions, known_facts_cnf): + """ + Method for inferring properties about objects. + + """ + if not satisfiable(And(known_facts_cnf, assumptions, proposition)): + return False + if not satisfiable(And(known_facts_cnf, assumptions, Not(proposition))): + return True + return None diff --git a/.venv/lib/python3.13/site-packages/sympy/assumptions/lra_satask.py b/.venv/lib/python3.13/site-packages/sympy/assumptions/lra_satask.py new file mode 100644 index 0000000000000000000000000000000000000000..53afe3e5abe99109ec01a47f19f1a8a4c99c5628 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/assumptions/lra_satask.py @@ -0,0 +1,286 @@ +from sympy.assumptions.assume import global_assumptions +from sympy.assumptions.cnf import CNF, EncodedCNF +from sympy.assumptions.ask import Q +from sympy.logic.inference import satisfiable +from sympy.logic.algorithms.lra_theory import UnhandledInput, ALLOWED_PRED +from sympy.matrices.kind import MatrixKind +from sympy.core.kind import NumberKind +from sympy.assumptions.assume import AppliedPredicate +from sympy.core.mul import Mul +from sympy.core.singleton import S + + +def lra_satask(proposition, assumptions=True, context=global_assumptions): + """ + Function to evaluate the proposition with assumptions using SAT algorithm + in conjunction with an Linear Real Arithmetic theory solver. + + Used to handle inequalities. Should eventually be depreciated and combined + into satask, but infinity handling and other things need to be implemented + before that can happen. + """ + props = CNF.from_prop(proposition) + _props = CNF.from_prop(~proposition) + + cnf = CNF.from_prop(assumptions) + assumptions = EncodedCNF() + assumptions.from_cnf(cnf) + + context_cnf = CNF() + if context: + context_cnf = context_cnf.extend(context) + + assumptions.add_from_cnf(context_cnf) + + return check_satisfiability(props, _props, assumptions) + +# Some predicates such as Q.prime can't be handled by lra_satask. +# For example, (x > 0) & (x < 1) & Q.prime(x) is unsat but lra_satask would think it was sat. +# WHITE_LIST is a list of predicates that can always be handled. +WHITE_LIST = ALLOWED_PRED | {Q.positive, Q.negative, Q.zero, Q.nonzero, Q.nonpositive, Q.nonnegative, + Q.extended_positive, Q.extended_negative, Q.extended_nonpositive, + Q.extended_negative, Q.extended_nonzero, Q.negative_infinite, + Q.positive_infinite} + + +def check_satisfiability(prop, _prop, factbase): + sat_true = factbase.copy() + sat_false = factbase.copy() + sat_true.add_from_cnf(prop) + sat_false.add_from_cnf(_prop) + + all_pred, all_exprs = get_all_pred_and_expr_from_enc_cnf(sat_true) + + for pred in all_pred: + if pred.function not in WHITE_LIST and pred.function != Q.ne: + raise UnhandledInput(f"LRASolver: {pred} is an unhandled predicate") + for expr in all_exprs: + if expr.kind == MatrixKind(NumberKind): + raise UnhandledInput(f"LRASolver: {expr} is of MatrixKind") + if expr == S.NaN: + raise UnhandledInput("LRASolver: nan") + + # convert old assumptions into predicates and add them to sat_true and sat_false + # also check for unhandled predicates + for assm in extract_pred_from_old_assum(all_exprs): + n = len(sat_true.encoding) + if assm not in sat_true.encoding: + sat_true.encoding[assm] = n+1 + sat_true.data.append([sat_true.encoding[assm]]) + + n = len(sat_false.encoding) + if assm not in sat_false.encoding: + sat_false.encoding[assm] = n+1 + sat_false.data.append([sat_false.encoding[assm]]) + + + sat_true = _preprocess(sat_true) + sat_false = _preprocess(sat_false) + + can_be_true = satisfiable(sat_true, use_lra_theory=True) is not False + can_be_false = satisfiable(sat_false, use_lra_theory=True) is not False + + if can_be_true and can_be_false: + return None + + if can_be_true and not can_be_false: + return True + + if not can_be_true and can_be_false: + return False + + if not can_be_true and not can_be_false: + raise ValueError("Inconsistent assumptions") + + +def _preprocess(enc_cnf): + """ + Returns an encoded cnf with only Q.eq, Q.gt, Q.lt, + Q.ge, and Q.le predicate. + + Converts every unequality into a disjunction of strict + inequalities. For example, x != 3 would become + x < 3 OR x > 3. + + Also converts all negated Q.ne predicates into + equalities. + """ + + # loops through each literal in each clause + # to construct a new, preprocessed encodedCNF + + enc_cnf = enc_cnf.copy() + cur_enc = 1 + rev_encoding = {value: key for key, value in enc_cnf.encoding.items()} + + new_encoding = {} + new_data = [] + for clause in enc_cnf.data: + new_clause = [] + for lit in clause: + if lit == 0: + new_clause.append(lit) + new_encoding[lit] = False + continue + prop = rev_encoding[abs(lit)] + negated = lit < 0 + sign = (lit > 0) - (lit < 0) + + prop = _pred_to_binrel(prop) + + if not isinstance(prop, AppliedPredicate): + if prop not in new_encoding: + new_encoding[prop] = cur_enc + cur_enc += 1 + lit = new_encoding[prop] + new_clause.append(sign*lit) + continue + + + if negated and prop.function == Q.eq: + negated = False + prop = Q.ne(*prop.arguments) + + if prop.function == Q.ne: + arg1, arg2 = prop.arguments + if negated: + new_prop = Q.eq(arg1, arg2) + if new_prop not in new_encoding: + new_encoding[new_prop] = cur_enc + cur_enc += 1 + + new_enc = new_encoding[new_prop] + new_clause.append(new_enc) + continue + else: + new_props = (Q.gt(arg1, arg2), Q.lt(arg1, arg2)) + for new_prop in new_props: + if new_prop not in new_encoding: + new_encoding[new_prop] = cur_enc + cur_enc += 1 + + new_enc = new_encoding[new_prop] + new_clause.append(new_enc) + continue + + if prop.function == Q.eq and negated: + assert False + + if prop not in new_encoding: + new_encoding[prop] = cur_enc + cur_enc += 1 + new_clause.append(new_encoding[prop]*sign) + new_data.append(new_clause) + + assert len(new_encoding) >= cur_enc - 1 + + enc_cnf = EncodedCNF(new_data, new_encoding) + return enc_cnf + + +def _pred_to_binrel(pred): + if not isinstance(pred, AppliedPredicate): + return pred + + if pred.function in pred_to_pos_neg_zero: + f = pred_to_pos_neg_zero[pred.function] + if f is False: + return False + pred = f(pred.arguments[0]) + + if pred.function == Q.positive: + pred = Q.gt(pred.arguments[0], 0) + elif pred.function == Q.negative: + pred = Q.lt(pred.arguments[0], 0) + elif pred.function == Q.zero: + pred = Q.eq(pred.arguments[0], 0) + elif pred.function == Q.nonpositive: + pred = Q.le(pred.arguments[0], 0) + elif pred.function == Q.nonnegative: + pred = Q.ge(pred.arguments[0], 0) + elif pred.function == Q.nonzero: + pred = Q.ne(pred.arguments[0], 0) + + return pred + +pred_to_pos_neg_zero = { + Q.extended_positive: Q.positive, + Q.extended_negative: Q.negative, + Q.extended_nonpositive: Q.nonpositive, + Q.extended_negative: Q.negative, + Q.extended_nonzero: Q.nonzero, + Q.negative_infinite: False, + Q.positive_infinite: False +} + +def get_all_pred_and_expr_from_enc_cnf(enc_cnf): + all_exprs = set() + all_pred = set() + for pred in enc_cnf.encoding.keys(): + if isinstance(pred, AppliedPredicate): + all_pred.add(pred) + all_exprs.update(pred.arguments) + + return all_pred, all_exprs + +def extract_pred_from_old_assum(all_exprs): + """ + Returns a list of relevant new assumption predicate + based on any old assumptions. + + Raises an UnhandledInput exception if any of the assumptions are + unhandled. + + Ignored predicate: + - commutative + - complex + - algebraic + - transcendental + - extended_real + - real + - all matrix predicate + - rational + - irrational + + Example + ======= + >>> from sympy.assumptions.lra_satask import extract_pred_from_old_assum + >>> from sympy import symbols + >>> x, y = symbols("x y", positive=True) + >>> extract_pred_from_old_assum([x, y, 2]) + [Q.positive(x), Q.positive(y)] + """ + ret = [] + for expr in all_exprs: + if not hasattr(expr, "free_symbols"): + continue + if len(expr.free_symbols) == 0: + continue + + if expr.is_real is not True: + raise UnhandledInput(f"LRASolver: {expr} must be real") + # test for I times imaginary variable; such expressions are considered real + if isinstance(expr, Mul) and any(arg.is_real is not True for arg in expr.args): + raise UnhandledInput(f"LRASolver: {expr} must be real") + + if expr.is_integer == True and expr.is_zero != True: + raise UnhandledInput(f"LRASolver: {expr} is an integer") + if expr.is_integer == False: + raise UnhandledInput(f"LRASolver: {expr} can't be an integer") + if expr.is_rational == False: + raise UnhandledInput(f"LRASolver: {expr} is irational") + + if expr.is_zero: + ret.append(Q.zero(expr)) + elif expr.is_positive: + ret.append(Q.positive(expr)) + elif expr.is_negative: + ret.append(Q.negative(expr)) + elif expr.is_nonzero: + ret.append(Q.nonzero(expr)) + elif expr.is_nonpositive: + ret.append(Q.nonpositive(expr)) + elif expr.is_nonnegative: + ret.append(Q.nonnegative(expr)) + + return ret diff --git a/.venv/lib/python3.13/site-packages/sympy/assumptions/refine.py b/.venv/lib/python3.13/site-packages/sympy/assumptions/refine.py new file mode 100644 index 0000000000000000000000000000000000000000..c36a4e1cdb40f1b59a96f60a3b36182b587920fa --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/assumptions/refine.py @@ -0,0 +1,405 @@ +from __future__ import annotations +from typing import Callable + +from sympy.core import S, Add, Expr, Basic, Mul, Pow, Rational +from sympy.core.logic import fuzzy_not +from sympy.logic.boolalg import Boolean + +from sympy.assumptions import ask, Q # type: ignore + + +def refine(expr, assumptions=True): + """ + Simplify an expression using assumptions. + + Explanation + =========== + + Unlike :func:`~.simplify` which performs structural simplification + without any assumption, this function transforms the expression into + the form which is only valid under certain assumptions. Note that + ``simplify()`` is generally not done in refining process. + + Refining boolean expression involves reducing it to ``S.true`` or + ``S.false``. Unlike :func:`~.ask`, the expression will not be reduced + if the truth value cannot be determined. + + Examples + ======== + + >>> from sympy import refine, sqrt, Q + >>> from sympy.abc import x + >>> refine(sqrt(x**2), Q.real(x)) + Abs(x) + >>> refine(sqrt(x**2), Q.positive(x)) + x + + >>> refine(Q.real(x), Q.positive(x)) + True + >>> refine(Q.positive(x), Q.real(x)) + Q.positive(x) + + See Also + ======== + + sympy.simplify.simplify.simplify : Structural simplification without assumptions. + sympy.assumptions.ask.ask : Query for boolean expressions using assumptions. + """ + if not isinstance(expr, Basic): + return expr + + if not expr.is_Atom: + args = [refine(arg, assumptions) for arg in expr.args] + # TODO: this will probably not work with Integral or Polynomial + expr = expr.func(*args) + if hasattr(expr, '_eval_refine'): + ref_expr = expr._eval_refine(assumptions) + if ref_expr is not None: + return ref_expr + name = expr.__class__.__name__ + handler = handlers_dict.get(name, None) + if handler is None: + return expr + new_expr = handler(expr, assumptions) + if (new_expr is None) or (expr == new_expr): + return expr + if not isinstance(new_expr, Expr): + return new_expr + return refine(new_expr, assumptions) + + +def refine_abs(expr, assumptions): + """ + Handler for the absolute value. + + Examples + ======== + + >>> from sympy import Q, Abs + >>> from sympy.assumptions.refine import refine_abs + >>> from sympy.abc import x + >>> refine_abs(Abs(x), Q.real(x)) + >>> refine_abs(Abs(x), Q.positive(x)) + x + >>> refine_abs(Abs(x), Q.negative(x)) + -x + + """ + from sympy.functions.elementary.complexes import Abs + arg = expr.args[0] + if ask(Q.real(arg), assumptions) and \ + fuzzy_not(ask(Q.negative(arg), assumptions)): + # if it's nonnegative + return arg + if ask(Q.negative(arg), assumptions): + return -arg + # arg is Mul + if isinstance(arg, Mul): + r = [refine(abs(a), assumptions) for a in arg.args] + non_abs = [] + in_abs = [] + for i in r: + if isinstance(i, Abs): + in_abs.append(i.args[0]) + else: + non_abs.append(i) + return Mul(*non_abs) * Abs(Mul(*in_abs)) + + +def refine_Pow(expr, assumptions): + """ + Handler for instances of Pow. + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.refine import refine_Pow + >>> from sympy.abc import x,y,z + >>> refine_Pow((-1)**x, Q.real(x)) + >>> refine_Pow((-1)**x, Q.even(x)) + 1 + >>> refine_Pow((-1)**x, Q.odd(x)) + -1 + + For powers of -1, even parts of the exponent can be simplified: + + >>> refine_Pow((-1)**(x+y), Q.even(x)) + (-1)**y + >>> refine_Pow((-1)**(x+y+z), Q.odd(x) & Q.odd(z)) + (-1)**y + >>> refine_Pow((-1)**(x+y+2), Q.odd(x)) + (-1)**(y + 1) + >>> refine_Pow((-1)**(x+3), True) + (-1)**(x + 1) + + """ + from sympy.functions.elementary.complexes import Abs + from sympy.functions import sign + if isinstance(expr.base, Abs): + if ask(Q.real(expr.base.args[0]), assumptions) and \ + ask(Q.even(expr.exp), assumptions): + return expr.base.args[0] ** expr.exp + if ask(Q.real(expr.base), assumptions): + if expr.base.is_number: + if ask(Q.even(expr.exp), assumptions): + return abs(expr.base) ** expr.exp + if ask(Q.odd(expr.exp), assumptions): + return sign(expr.base) * abs(expr.base) ** expr.exp + if isinstance(expr.exp, Rational): + if isinstance(expr.base, Pow): + return abs(expr.base.base) ** (expr.base.exp * expr.exp) + + if expr.base is S.NegativeOne: + if expr.exp.is_Add: + + old = expr + + # For powers of (-1) we can remove + # - even terms + # - pairs of odd terms + # - a single odd term + 1 + # - A numerical constant N can be replaced with mod(N,2) + + coeff, terms = expr.exp.as_coeff_add() + terms = set(terms) + even_terms = set() + odd_terms = set() + initial_number_of_terms = len(terms) + + for t in terms: + if ask(Q.even(t), assumptions): + even_terms.add(t) + elif ask(Q.odd(t), assumptions): + odd_terms.add(t) + + terms -= even_terms + if len(odd_terms) % 2: + terms -= odd_terms + new_coeff = (coeff + S.One) % 2 + else: + terms -= odd_terms + new_coeff = coeff % 2 + + if new_coeff != coeff or len(terms) < initial_number_of_terms: + terms.add(new_coeff) + expr = expr.base**(Add(*terms)) + + # Handle (-1)**((-1)**n/2 + m/2) + e2 = 2*expr.exp + if ask(Q.even(e2), assumptions): + if e2.could_extract_minus_sign(): + e2 *= expr.base + if e2.is_Add: + i, p = e2.as_two_terms() + if p.is_Pow and p.base is S.NegativeOne: + if ask(Q.integer(p.exp), assumptions): + i = (i + 1)/2 + if ask(Q.even(i), assumptions): + return expr.base**p.exp + elif ask(Q.odd(i), assumptions): + return expr.base**(p.exp + 1) + else: + return expr.base**(p.exp + i) + + if old != expr: + return expr + + +def refine_atan2(expr, assumptions): + """ + Handler for the atan2 function. + + Examples + ======== + + >>> from sympy import Q, atan2 + >>> from sympy.assumptions.refine import refine_atan2 + >>> from sympy.abc import x, y + >>> refine_atan2(atan2(y,x), Q.real(y) & Q.positive(x)) + atan(y/x) + >>> refine_atan2(atan2(y,x), Q.negative(y) & Q.negative(x)) + atan(y/x) - pi + >>> refine_atan2(atan2(y,x), Q.positive(y) & Q.negative(x)) + atan(y/x) + pi + >>> refine_atan2(atan2(y,x), Q.zero(y) & Q.negative(x)) + pi + >>> refine_atan2(atan2(y,x), Q.positive(y) & Q.zero(x)) + pi/2 + >>> refine_atan2(atan2(y,x), Q.negative(y) & Q.zero(x)) + -pi/2 + >>> refine_atan2(atan2(y,x), Q.zero(y) & Q.zero(x)) + nan + """ + from sympy.functions.elementary.trigonometric import atan + y, x = expr.args + if ask(Q.real(y) & Q.positive(x), assumptions): + return atan(y / x) + elif ask(Q.negative(y) & Q.negative(x), assumptions): + return atan(y / x) - S.Pi + elif ask(Q.positive(y) & Q.negative(x), assumptions): + return atan(y / x) + S.Pi + elif ask(Q.zero(y) & Q.negative(x), assumptions): + return S.Pi + elif ask(Q.positive(y) & Q.zero(x), assumptions): + return S.Pi/2 + elif ask(Q.negative(y) & Q.zero(x), assumptions): + return -S.Pi/2 + elif ask(Q.zero(y) & Q.zero(x), assumptions): + return S.NaN + else: + return expr + + +def refine_re(expr, assumptions): + """ + Handler for real part. + + Examples + ======== + + >>> from sympy.assumptions.refine import refine_re + >>> from sympy import Q, re + >>> from sympy.abc import x + >>> refine_re(re(x), Q.real(x)) + x + >>> refine_re(re(x), Q.imaginary(x)) + 0 + """ + arg = expr.args[0] + if ask(Q.real(arg), assumptions): + return arg + if ask(Q.imaginary(arg), assumptions): + return S.Zero + return _refine_reim(expr, assumptions) + + +def refine_im(expr, assumptions): + """ + Handler for imaginary part. + + Explanation + =========== + + >>> from sympy.assumptions.refine import refine_im + >>> from sympy import Q, im + >>> from sympy.abc import x + >>> refine_im(im(x), Q.real(x)) + 0 + >>> refine_im(im(x), Q.imaginary(x)) + -I*x + """ + arg = expr.args[0] + if ask(Q.real(arg), assumptions): + return S.Zero + if ask(Q.imaginary(arg), assumptions): + return - S.ImaginaryUnit * arg + return _refine_reim(expr, assumptions) + +def refine_arg(expr, assumptions): + """ + Handler for complex argument + + Explanation + =========== + + >>> from sympy.assumptions.refine import refine_arg + >>> from sympy import Q, arg + >>> from sympy.abc import x + >>> refine_arg(arg(x), Q.positive(x)) + 0 + >>> refine_arg(arg(x), Q.negative(x)) + pi + """ + rg = expr.args[0] + if ask(Q.positive(rg), assumptions): + return S.Zero + if ask(Q.negative(rg), assumptions): + return S.Pi + return None + + +def _refine_reim(expr, assumptions): + # Helper function for refine_re & refine_im + expanded = expr.expand(complex = True) + if expanded != expr: + refined = refine(expanded, assumptions) + if refined != expanded: + return refined + # Best to leave the expression as is + return None + + +def refine_sign(expr, assumptions): + """ + Handler for sign. + + Examples + ======== + + >>> from sympy.assumptions.refine import refine_sign + >>> from sympy import Symbol, Q, sign, im + >>> x = Symbol('x', real = True) + >>> expr = sign(x) + >>> refine_sign(expr, Q.positive(x) & Q.nonzero(x)) + 1 + >>> refine_sign(expr, Q.negative(x) & Q.nonzero(x)) + -1 + >>> refine_sign(expr, Q.zero(x)) + 0 + >>> y = Symbol('y', imaginary = True) + >>> expr = sign(y) + >>> refine_sign(expr, Q.positive(im(y))) + I + >>> refine_sign(expr, Q.negative(im(y))) + -I + """ + arg = expr.args[0] + if ask(Q.zero(arg), assumptions): + return S.Zero + if ask(Q.real(arg)): + if ask(Q.positive(arg), assumptions): + return S.One + if ask(Q.negative(arg), assumptions): + return S.NegativeOne + if ask(Q.imaginary(arg)): + arg_re, arg_im = arg.as_real_imag() + if ask(Q.positive(arg_im), assumptions): + return S.ImaginaryUnit + if ask(Q.negative(arg_im), assumptions): + return -S.ImaginaryUnit + return expr + + +def refine_matrixelement(expr, assumptions): + """ + Handler for symmetric part. + + Examples + ======== + + >>> from sympy.assumptions.refine import refine_matrixelement + >>> from sympy import MatrixSymbol, Q + >>> X = MatrixSymbol('X', 3, 3) + >>> refine_matrixelement(X[0, 1], Q.symmetric(X)) + X[0, 1] + >>> refine_matrixelement(X[1, 0], Q.symmetric(X)) + X[0, 1] + """ + from sympy.matrices.expressions.matexpr import MatrixElement + matrix, i, j = expr.args + if ask(Q.symmetric(matrix), assumptions): + if (i - j).could_extract_minus_sign(): + return expr + return MatrixElement(matrix, j, i) + +handlers_dict: dict[str, Callable[[Expr, Boolean], Expr]] = { + 'Abs': refine_abs, + 'Pow': refine_Pow, + 'atan2': refine_atan2, + 're': refine_re, + 'im': refine_im, + 'arg': refine_arg, + 'sign': refine_sign, + 'MatrixElement': refine_matrixelement +} diff --git a/.venv/lib/python3.13/site-packages/sympy/assumptions/satask.py b/.venv/lib/python3.13/site-packages/sympy/assumptions/satask.py new file mode 100644 index 0000000000000000000000000000000000000000..ffc13f6d3bc3fb7f573c8d5d0564b780440c1a8c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/assumptions/satask.py @@ -0,0 +1,369 @@ +""" +Module to evaluate the proposition with assumptions using SAT algorithm. +""" + +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.kind import NumberKind, UndefinedKind +from sympy.assumptions.ask_generated import get_all_known_matrix_facts, get_all_known_number_facts +from sympy.assumptions.assume import global_assumptions, AppliedPredicate +from sympy.assumptions.sathandlers import class_fact_registry +from sympy.core import oo +from sympy.logic.inference import satisfiable +from sympy.assumptions.cnf import CNF, EncodedCNF +from sympy.matrices.kind import MatrixKind + + +def satask(proposition, assumptions=True, context=global_assumptions, + use_known_facts=True, iterations=oo): + """ + Function to evaluate the proposition with assumptions using SAT algorithm. + + This function extracts every fact relevant to the expressions composing + proposition and assumptions. For example, if a predicate containing + ``Abs(x)`` is proposed, then ``Q.zero(Abs(x)) | Q.positive(Abs(x))`` + will be found and passed to SAT solver because ``Q.nonnegative`` is + registered as a fact for ``Abs``. + + Proposition is evaluated to ``True`` or ``False`` if the truth value can be + determined. If not, ``None`` is returned. + + Parameters + ========== + + proposition : Any boolean expression. + Proposition which will be evaluated to boolean value. + + assumptions : Any boolean expression, optional. + Local assumptions to evaluate the *proposition*. + + context : AssumptionsContext, optional. + Default assumptions to evaluate the *proposition*. By default, + this is ``sympy.assumptions.global_assumptions`` variable. + + use_known_facts : bool, optional. + If ``True``, facts from ``sympy.assumptions.ask_generated`` + module are passed to SAT solver as well. + + iterations : int, optional. + Number of times that relevant facts are recursively extracted. + Default is infinite times until no new fact is found. + + Returns + ======= + + ``True``, ``False``, or ``None`` + + Examples + ======== + + >>> from sympy import Abs, Q + >>> from sympy.assumptions.satask import satask + >>> from sympy.abc import x + >>> satask(Q.zero(Abs(x)), Q.zero(x)) + True + + """ + props = CNF.from_prop(proposition) + _props = CNF.from_prop(~proposition) + + assumptions = CNF.from_prop(assumptions) + + context_cnf = CNF() + if context: + context_cnf = context_cnf.extend(context) + + sat = get_all_relevant_facts(props, assumptions, context_cnf, + use_known_facts=use_known_facts, iterations=iterations) + sat.add_from_cnf(assumptions) + if context: + sat.add_from_cnf(context_cnf) + + return check_satisfiability(props, _props, sat) + + +def check_satisfiability(prop, _prop, factbase): + sat_true = factbase.copy() + sat_false = factbase.copy() + sat_true.add_from_cnf(prop) + sat_false.add_from_cnf(_prop) + can_be_true = satisfiable(sat_true) + can_be_false = satisfiable(sat_false) + + if can_be_true and can_be_false: + return None + + if can_be_true and not can_be_false: + return True + + if not can_be_true and can_be_false: + return False + + if not can_be_true and not can_be_false: + # TODO: Run additional checks to see which combination of the + # assumptions, global_assumptions, and relevant_facts are + # inconsistent. + raise ValueError("Inconsistent assumptions") + + +def extract_predargs(proposition, assumptions=None, context=None): + """ + Extract every expression in the argument of predicates from *proposition*, + *assumptions* and *context*. + + Parameters + ========== + + proposition : sympy.assumptions.cnf.CNF + + assumptions : sympy.assumptions.cnf.CNF, optional. + + context : sympy.assumptions.cnf.CNF, optional. + CNF generated from assumptions context. + + Examples + ======== + + >>> from sympy import Q, Abs + >>> from sympy.assumptions.cnf import CNF + >>> from sympy.assumptions.satask import extract_predargs + >>> from sympy.abc import x, y + >>> props = CNF.from_prop(Q.zero(Abs(x*y))) + >>> assump = CNF.from_prop(Q.zero(x) & Q.zero(y)) + >>> extract_predargs(props, assump) + {x, y, Abs(x*y)} + + """ + req_keys = find_symbols(proposition) + keys = proposition.all_predicates() + # XXX: We need this since True/False are not Basic + lkeys = set() + if assumptions: + lkeys |= assumptions.all_predicates() + if context: + lkeys |= context.all_predicates() + + lkeys = lkeys - {S.true, S.false} + tmp_keys = None + while tmp_keys != set(): + tmp = set() + for l in lkeys: + syms = find_symbols(l) + if (syms & req_keys) != set(): + tmp |= syms + tmp_keys = tmp - req_keys + req_keys |= tmp_keys + keys |= {l for l in lkeys if find_symbols(l) & req_keys != set()} + + exprs = set() + for key in keys: + if isinstance(key, AppliedPredicate): + exprs |= set(key.arguments) + else: + exprs.add(key) + return exprs + +def find_symbols(pred): + """ + Find every :obj:`~.Symbol` in *pred*. + + Parameters + ========== + + pred : sympy.assumptions.cnf.CNF, or any Expr. + + """ + if isinstance(pred, CNF): + symbols = set() + for a in pred.all_predicates(): + symbols |= find_symbols(a) + return symbols + return pred.atoms(Symbol) + + +def get_relevant_clsfacts(exprs, relevant_facts=None): + """ + Extract relevant facts from the items in *exprs*. Facts are defined in + ``assumptions.sathandlers`` module. + + This function is recursively called by ``get_all_relevant_facts()``. + + Parameters + ========== + + exprs : set + Expressions whose relevant facts are searched. + + relevant_facts : sympy.assumptions.cnf.CNF, optional. + Pre-discovered relevant facts. + + Returns + ======= + + exprs : set + Candidates for next relevant fact searching. + + relevant_facts : sympy.assumptions.cnf.CNF + Updated relevant facts. + + Examples + ======== + + Here, we will see how facts relevant to ``Abs(x*y)`` are recursively + extracted. On the first run, set containing the expression is passed + without pre-discovered relevant facts. The result is a set containing + candidates for next run, and ``CNF()`` instance containing facts + which are relevant to ``Abs`` and its argument. + + >>> from sympy import Abs + >>> from sympy.assumptions.satask import get_relevant_clsfacts + >>> from sympy.abc import x, y + >>> exprs = {Abs(x*y)} + >>> exprs, facts = get_relevant_clsfacts(exprs) + >>> exprs + {x*y} + >>> facts.clauses #doctest: +SKIP + {frozenset({Literal(Q.odd(Abs(x*y)), False), Literal(Q.odd(x*y), True)}), + frozenset({Literal(Q.zero(Abs(x*y)), False), Literal(Q.zero(x*y), True)}), + frozenset({Literal(Q.even(Abs(x*y)), False), Literal(Q.even(x*y), True)}), + frozenset({Literal(Q.zero(Abs(x*y)), True), Literal(Q.zero(x*y), False)}), + frozenset({Literal(Q.even(Abs(x*y)), False), + Literal(Q.odd(Abs(x*y)), False), + Literal(Q.odd(x*y), True)}), + frozenset({Literal(Q.even(Abs(x*y)), False), + Literal(Q.even(x*y), True), + Literal(Q.odd(Abs(x*y)), False)}), + frozenset({Literal(Q.positive(Abs(x*y)), False), + Literal(Q.zero(Abs(x*y)), False)})} + + We pass the first run's results to the second run, and get the expressions + for next run and updated facts. + + >>> exprs, facts = get_relevant_clsfacts(exprs, relevant_facts=facts) + >>> exprs + {x, y} + + On final run, no more candidate is returned thus we know that all + relevant facts are successfully retrieved. + + >>> exprs, facts = get_relevant_clsfacts(exprs, relevant_facts=facts) + >>> exprs + set() + + """ + if not relevant_facts: + relevant_facts = CNF() + + newexprs = set() + for expr in exprs: + for fact in class_fact_registry(expr): + newfact = CNF.to_CNF(fact) + relevant_facts = relevant_facts._and(newfact) + for key in newfact.all_predicates(): + if isinstance(key, AppliedPredicate): + newexprs |= set(key.arguments) + + return newexprs - exprs, relevant_facts + + +def get_all_relevant_facts(proposition, assumptions, context, + use_known_facts=True, iterations=oo): + """ + Extract all relevant facts from *proposition* and *assumptions*. + + This function extracts the facts by recursively calling + ``get_relevant_clsfacts()``. Extracted facts are converted to + ``EncodedCNF`` and returned. + + Parameters + ========== + + proposition : sympy.assumptions.cnf.CNF + CNF generated from proposition expression. + + assumptions : sympy.assumptions.cnf.CNF + CNF generated from assumption expression. + + context : sympy.assumptions.cnf.CNF + CNF generated from assumptions context. + + use_known_facts : bool, optional. + If ``True``, facts from ``sympy.assumptions.ask_generated`` + module are encoded as well. + + iterations : int, optional. + Number of times that relevant facts are recursively extracted. + Default is infinite times until no new fact is found. + + Returns + ======= + + sympy.assumptions.cnf.EncodedCNF + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.cnf import CNF + >>> from sympy.assumptions.satask import get_all_relevant_facts + >>> from sympy.abc import x, y + >>> props = CNF.from_prop(Q.nonzero(x*y)) + >>> assump = CNF.from_prop(Q.nonzero(x)) + >>> context = CNF.from_prop(Q.nonzero(y)) + >>> get_all_relevant_facts(props, assump, context) #doctest: +SKIP + + + """ + # The relevant facts might introduce new keys, e.g., Q.zero(x*y) will + # introduce the keys Q.zero(x) and Q.zero(y), so we need to run it until + # we stop getting new things. Hopefully this strategy won't lead to an + # infinite loop in the future. + i = 0 + relevant_facts = CNF() + all_exprs = set() + while True: + if i == 0: + exprs = extract_predargs(proposition, assumptions, context) + all_exprs |= exprs + exprs, relevant_facts = get_relevant_clsfacts(exprs, relevant_facts) + i += 1 + if i >= iterations: + break + if not exprs: + break + + if use_known_facts: + known_facts_CNF = CNF() + + if any(expr.kind == MatrixKind(NumberKind) for expr in all_exprs): + known_facts_CNF.add_clauses(get_all_known_matrix_facts()) + # check for undefinedKind since kind system isn't fully implemented + if any(((expr.kind == NumberKind) or (expr.kind == UndefinedKind)) for expr in all_exprs): + known_facts_CNF.add_clauses(get_all_known_number_facts()) + + kf_encoded = EncodedCNF() + kf_encoded.from_cnf(known_facts_CNF) + + def translate_literal(lit, delta): + if lit > 0: + return lit + delta + else: + return lit - delta + + def translate_data(data, delta): + return [{translate_literal(i, delta) for i in clause} for clause in data] + data = [] + symbols = [] + n_lit = len(kf_encoded.symbols) + for i, expr in enumerate(all_exprs): + symbols += [pred(expr) for pred in kf_encoded.symbols] + data += translate_data(kf_encoded.data, i * n_lit) + + encoding = dict(list(zip(symbols, range(1, len(symbols)+1)))) + ctx = EncodedCNF(data, encoding) + else: + ctx = EncodedCNF() + + ctx.add_from_cnf(relevant_facts) + + return ctx diff --git a/.venv/lib/python3.13/site-packages/sympy/assumptions/sathandlers.py b/.venv/lib/python3.13/site-packages/sympy/assumptions/sathandlers.py new file mode 100644 index 0000000000000000000000000000000000000000..a11199eb0e547187ab280c18196c0259c178e004 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/assumptions/sathandlers.py @@ -0,0 +1,322 @@ +from collections import defaultdict + +from sympy.assumptions.ask import Q +from sympy.core import (Add, Mul, Pow, Number, NumberSymbol, Symbol) +from sympy.core.numbers import ImaginaryUnit +from sympy.functions.elementary.complexes import Abs +from sympy.logic.boolalg import (Equivalent, And, Or, Implies) +from sympy.matrices.expressions import MatMul + +# APIs here may be subject to change + + +### Helper functions ### + +def allargs(symbol, fact, expr): + """ + Apply all arguments of the expression to the fact structure. + + Parameters + ========== + + symbol : Symbol + A placeholder symbol. + + fact : Boolean + Resulting ``Boolean`` expression. + + expr : Expr + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.sathandlers import allargs + >>> from sympy.abc import x, y + >>> allargs(x, Q.negative(x) | Q.positive(x), x*y) + (Q.negative(x) | Q.positive(x)) & (Q.negative(y) | Q.positive(y)) + + """ + return And(*[fact.subs(symbol, arg) for arg in expr.args]) + + +def anyarg(symbol, fact, expr): + """ + Apply any argument of the expression to the fact structure. + + Parameters + ========== + + symbol : Symbol + A placeholder symbol. + + fact : Boolean + Resulting ``Boolean`` expression. + + expr : Expr + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.sathandlers import anyarg + >>> from sympy.abc import x, y + >>> anyarg(x, Q.negative(x) & Q.positive(x), x*y) + (Q.negative(x) & Q.positive(x)) | (Q.negative(y) & Q.positive(y)) + + """ + return Or(*[fact.subs(symbol, arg) for arg in expr.args]) + + +def exactlyonearg(symbol, fact, expr): + """ + Apply exactly one argument of the expression to the fact structure. + + Parameters + ========== + + symbol : Symbol + A placeholder symbol. + + fact : Boolean + Resulting ``Boolean`` expression. + + expr : Expr + + Examples + ======== + + >>> from sympy import Q + >>> from sympy.assumptions.sathandlers import exactlyonearg + >>> from sympy.abc import x, y + >>> exactlyonearg(x, Q.positive(x), x*y) + (Q.positive(x) & ~Q.positive(y)) | (Q.positive(y) & ~Q.positive(x)) + + """ + pred_args = [fact.subs(symbol, arg) for arg in expr.args] + res = Or(*[And(pred_args[i], *[~lit for lit in pred_args[:i] + + pred_args[i+1:]]) for i in range(len(pred_args))]) + return res + + +### Fact registry ### + +class ClassFactRegistry: + """ + Register handlers against classes. + + Explanation + =========== + + ``register`` method registers the handler function for a class. Here, + handler function should return a single fact. ``multiregister`` method + registers the handler function for multiple classes. Here, handler function + should return a container of multiple facts. + + ``registry(expr)`` returns a set of facts for *expr*. + + Examples + ======== + + Here, we register the facts for ``Abs``. + + >>> from sympy import Abs, Equivalent, Q + >>> from sympy.assumptions.sathandlers import ClassFactRegistry + >>> reg = ClassFactRegistry() + >>> @reg.register(Abs) + ... def f1(expr): + ... return Q.nonnegative(expr) + >>> @reg.register(Abs) + ... def f2(expr): + ... arg = expr.args[0] + ... return Equivalent(~Q.zero(arg), ~Q.zero(expr)) + + Calling the registry with expression returns the defined facts for the + expression. + + >>> from sympy.abc import x + >>> reg(Abs(x)) + {Q.nonnegative(Abs(x)), Equivalent(~Q.zero(x), ~Q.zero(Abs(x)))} + + Multiple facts can be registered at once by ``multiregister`` method. + + >>> reg2 = ClassFactRegistry() + >>> @reg2.multiregister(Abs) + ... def _(expr): + ... arg = expr.args[0] + ... return [Q.even(arg) >> Q.even(expr), Q.odd(arg) >> Q.odd(expr)] + >>> reg2(Abs(x)) + {Implies(Q.even(x), Q.even(Abs(x))), Implies(Q.odd(x), Q.odd(Abs(x)))} + + """ + def __init__(self): + self.singlefacts = defaultdict(frozenset) + self.multifacts = defaultdict(frozenset) + + def register(self, cls): + def _(func): + self.singlefacts[cls] |= {func} + return func + return _ + + def multiregister(self, *classes): + def _(func): + for cls in classes: + self.multifacts[cls] |= {func} + return func + return _ + + def __getitem__(self, key): + ret1 = self.singlefacts[key] + for k in self.singlefacts: + if issubclass(key, k): + ret1 |= self.singlefacts[k] + + ret2 = self.multifacts[key] + for k in self.multifacts: + if issubclass(key, k): + ret2 |= self.multifacts[k] + + return ret1, ret2 + + def __call__(self, expr): + ret = set() + + handlers1, handlers2 = self[type(expr)] + + ret.update(h(expr) for h in handlers1) + for h in handlers2: + ret.update(h(expr)) + return ret + +class_fact_registry = ClassFactRegistry() + + + +### Class fact registration ### + +x = Symbol('x') + +## Abs ## + +@class_fact_registry.multiregister(Abs) +def _(expr): + arg = expr.args[0] + return [Q.nonnegative(expr), + Equivalent(~Q.zero(arg), ~Q.zero(expr)), + Q.even(arg) >> Q.even(expr), + Q.odd(arg) >> Q.odd(expr), + Q.integer(arg) >> Q.integer(expr), + ] + + +### Add ## + +@class_fact_registry.multiregister(Add) +def _(expr): + return [allargs(x, Q.positive(x), expr) >> Q.positive(expr), + allargs(x, Q.negative(x), expr) >> Q.negative(expr), + allargs(x, Q.real(x), expr) >> Q.real(expr), + allargs(x, Q.rational(x), expr) >> Q.rational(expr), + allargs(x, Q.integer(x), expr) >> Q.integer(expr), + exactlyonearg(x, ~Q.integer(x), expr) >> ~Q.integer(expr), + ] + +@class_fact_registry.register(Add) +def _(expr): + allargs_real = allargs(x, Q.real(x), expr) + onearg_irrational = exactlyonearg(x, Q.irrational(x), expr) + return Implies(allargs_real, Implies(onearg_irrational, Q.irrational(expr))) + + +### Mul ### + +@class_fact_registry.multiregister(Mul) +def _(expr): + return [Equivalent(Q.zero(expr), anyarg(x, Q.zero(x), expr)), + allargs(x, Q.positive(x), expr) >> Q.positive(expr), + allargs(x, Q.real(x), expr) >> Q.real(expr), + allargs(x, Q.rational(x), expr) >> Q.rational(expr), + allargs(x, Q.integer(x), expr) >> Q.integer(expr), + exactlyonearg(x, ~Q.rational(x), expr) >> ~Q.integer(expr), + allargs(x, Q.commutative(x), expr) >> Q.commutative(expr), + ] + +@class_fact_registry.register(Mul) +def _(expr): + # Implicitly assumes Mul has more than one arg + # Would be allargs(x, Q.prime(x) | Q.composite(x)) except 1 is composite + # More advanced prime assumptions will require inequalities, as 1 provides + # a corner case. + allargs_prime = allargs(x, Q.prime(x), expr) + return Implies(allargs_prime, ~Q.prime(expr)) + +@class_fact_registry.register(Mul) +def _(expr): + # General Case: Odd number of imaginary args implies mul is imaginary(To be implemented) + allargs_imag_or_real = allargs(x, Q.imaginary(x) | Q.real(x), expr) + onearg_imaginary = exactlyonearg(x, Q.imaginary(x), expr) + return Implies(allargs_imag_or_real, Implies(onearg_imaginary, Q.imaginary(expr))) + +@class_fact_registry.register(Mul) +def _(expr): + allargs_real = allargs(x, Q.real(x), expr) + onearg_irrational = exactlyonearg(x, Q.irrational(x), expr) + return Implies(allargs_real, Implies(onearg_irrational, Q.irrational(expr))) + +@class_fact_registry.register(Mul) +def _(expr): + # Including the integer qualification means we don't need to add any facts + # for odd, since the assumptions already know that every integer is + # exactly one of even or odd. + allargs_integer = allargs(x, Q.integer(x), expr) + anyarg_even = anyarg(x, Q.even(x), expr) + return Implies(allargs_integer, Equivalent(anyarg_even, Q.even(expr))) + + +### MatMul ### + +@class_fact_registry.register(MatMul) +def _(expr): + allargs_square = allargs(x, Q.square(x), expr) + allargs_invertible = allargs(x, Q.invertible(x), expr) + return Implies(allargs_square, Equivalent(Q.invertible(expr), allargs_invertible)) + + +### Pow ### + +@class_fact_registry.multiregister(Pow) +def _(expr): + base, exp = expr.base, expr.exp + return [ + (Q.real(base) & Q.even(exp) & Q.nonnegative(exp)) >> Q.nonnegative(expr), + (Q.nonnegative(base) & Q.odd(exp) & Q.nonnegative(exp)) >> Q.nonnegative(expr), + (Q.nonpositive(base) & Q.odd(exp) & Q.nonnegative(exp)) >> Q.nonpositive(expr), + Equivalent(Q.zero(expr), Q.zero(base) & Q.positive(exp)) + ] + + +### Numbers ### + +_old_assump_getters = { + Q.positive: lambda o: o.is_positive, + Q.zero: lambda o: o.is_zero, + Q.negative: lambda o: o.is_negative, + Q.rational: lambda o: o.is_rational, + Q.irrational: lambda o: o.is_irrational, + Q.even: lambda o: o.is_even, + Q.odd: lambda o: o.is_odd, + Q.imaginary: lambda o: o.is_imaginary, + Q.prime: lambda o: o.is_prime, + Q.composite: lambda o: o.is_composite, +} + +@class_fact_registry.multiregister(Number, NumberSymbol, ImaginaryUnit) +def _(expr): + ret = [] + for p, getter in _old_assump_getters.items(): + pred = p(expr) + prop = getter(expr) + if prop is not None: + ret.append(Equivalent(pred, prop)) + return ret diff --git a/.venv/lib/python3.13/site-packages/sympy/assumptions/wrapper.py b/.venv/lib/python3.13/site-packages/sympy/assumptions/wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..cb06e9de770ed41a2b3d6fe63381ad1cb59acacc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/assumptions/wrapper.py @@ -0,0 +1,164 @@ +""" +Functions and wrapper object to call assumption property and predicate +query with same syntax. + +In SymPy, there are two assumption systems. Old assumption system is +defined in sympy/core/assumptions, and it can be accessed by attribute +such as ``x.is_even``. New assumption system is defined in +sympy/assumptions, and it can be accessed by predicates such as +``Q.even(x)``. + +Old assumption is fast, while new assumptions can freely take local facts. +In general, old assumption is used in evaluation method and new assumption +is used in refinement method. + +In most cases, both evaluation and refinement follow the same process, and +the only difference is which assumption system is used. This module provides +``is_[...]()`` functions and ``AssumptionsWrapper()`` class which allows +using two systems with same syntax so that parallel code implementation can be +avoided. + +Examples +======== + +For multiple use, use ``AssumptionsWrapper()``. + +>>> from sympy import Q, Symbol +>>> from sympy.assumptions.wrapper import AssumptionsWrapper +>>> x = Symbol('x') +>>> _x = AssumptionsWrapper(x, Q.even(x)) +>>> _x.is_integer +True +>>> _x.is_odd +False + +For single use, use ``is_[...]()`` functions. + +>>> from sympy.assumptions.wrapper import is_infinite +>>> a = Symbol('a') +>>> print(is_infinite(a)) +None +>>> is_infinite(a, Q.finite(a)) +False + +""" + +from sympy.assumptions import ask, Q +from sympy.core.basic import Basic +from sympy.core.sympify import _sympify + + +def make_eval_method(fact): + def getit(self): + pred = getattr(Q, fact) + ret = ask(pred(self.expr), self.assumptions) + return ret + return getit + + +# we subclass Basic to use the fact deduction and caching +class AssumptionsWrapper(Basic): + """ + Wrapper over ``Basic`` instances to call predicate query by + ``.is_[...]`` property + + Parameters + ========== + + expr : Basic + + assumptions : Boolean, optional + + Examples + ======== + + >>> from sympy import Q, Symbol + >>> from sympy.assumptions.wrapper import AssumptionsWrapper + >>> x = Symbol('x', even=True) + >>> AssumptionsWrapper(x).is_integer + True + >>> y = Symbol('y') + >>> AssumptionsWrapper(y, Q.even(y)).is_integer + True + + With ``AssumptionsWrapper``, both evaluation and refinement can be supported + by single implementation. + + >>> from sympy import Function + >>> class MyAbs(Function): + ... @classmethod + ... def eval(cls, x, assumptions=True): + ... _x = AssumptionsWrapper(x, assumptions) + ... if _x.is_nonnegative: + ... return x + ... if _x.is_negative: + ... return -x + ... def _eval_refine(self, assumptions): + ... return MyAbs.eval(self.args[0], assumptions) + >>> MyAbs(x) + MyAbs(x) + >>> MyAbs(x).refine(Q.positive(x)) + x + >>> MyAbs(Symbol('y', negative=True)) + -y + + """ + def __new__(cls, expr, assumptions=None): + if assumptions is None: + return expr + obj = super().__new__(cls, expr, _sympify(assumptions)) + obj.expr = expr + obj.assumptions = assumptions + return obj + + _eval_is_algebraic = make_eval_method("algebraic") + _eval_is_antihermitian = make_eval_method("antihermitian") + _eval_is_commutative = make_eval_method("commutative") + _eval_is_complex = make_eval_method("complex") + _eval_is_composite = make_eval_method("composite") + _eval_is_even = make_eval_method("even") + _eval_is_extended_negative = make_eval_method("extended_negative") + _eval_is_extended_nonnegative = make_eval_method("extended_nonnegative") + _eval_is_extended_nonpositive = make_eval_method("extended_nonpositive") + _eval_is_extended_nonzero = make_eval_method("extended_nonzero") + _eval_is_extended_positive = make_eval_method("extended_positive") + _eval_is_extended_real = make_eval_method("extended_real") + _eval_is_finite = make_eval_method("finite") + _eval_is_hermitian = make_eval_method("hermitian") + _eval_is_imaginary = make_eval_method("imaginary") + _eval_is_infinite = make_eval_method("infinite") + _eval_is_integer = make_eval_method("integer") + _eval_is_irrational = make_eval_method("irrational") + _eval_is_negative = make_eval_method("negative") + _eval_is_noninteger = make_eval_method("noninteger") + _eval_is_nonnegative = make_eval_method("nonnegative") + _eval_is_nonpositive = make_eval_method("nonpositive") + _eval_is_nonzero = make_eval_method("nonzero") + _eval_is_odd = make_eval_method("odd") + _eval_is_polar = make_eval_method("polar") + _eval_is_positive = make_eval_method("positive") + _eval_is_prime = make_eval_method("prime") + _eval_is_rational = make_eval_method("rational") + _eval_is_real = make_eval_method("real") + _eval_is_transcendental = make_eval_method("transcendental") + _eval_is_zero = make_eval_method("zero") + + +# one shot functions which are faster than AssumptionsWrapper + +def is_infinite(obj, assumptions=None): + if assumptions is None: + return obj.is_infinite + return ask(Q.infinite(obj), assumptions) + + +def is_extended_real(obj, assumptions=None): + if assumptions is None: + return obj.is_extended_real + return ask(Q.extended_real(obj), assumptions) + + +def is_extended_nonnegative(obj, assumptions=None): + if assumptions is None: + return obj.is_extended_nonnegative + return ask(Q.extended_nonnegative(obj), assumptions) diff --git a/.venv/lib/python3.13/site-packages/sympy/benchmarks/__init__.py b/.venv/lib/python3.13/site-packages/sympy/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.13/site-packages/sympy/benchmarks/bench_discrete_log.py b/.venv/lib/python3.13/site-packages/sympy/benchmarks/bench_discrete_log.py new file mode 100644 index 0000000000000000000000000000000000000000..76b273909e415318a7d3bace00ffff2a0bc53762 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/benchmarks/bench_discrete_log.py @@ -0,0 +1,83 @@ +import sys +from time import time +from sympy.ntheory.residue_ntheory import (discrete_log, + _discrete_log_trial_mul, _discrete_log_shanks_steps, + _discrete_log_pollard_rho, _discrete_log_pohlig_hellman) + + +# Cyclic group (Z/pZ)* with p prime, order p - 1 and generator g +data_set_1 = [ + # p, p - 1, g + [191, 190, 19], + [46639, 46638, 6], + [14789363, 14789362, 2], + [4254225211, 4254225210, 2], + [432751500361, 432751500360, 7], + [158505390797053, 158505390797052, 2], + [6575202655312007, 6575202655312006, 5], + [8430573471995353769, 8430573471995353768, 3], + [3938471339744997827267, 3938471339744997827266, 2], + [875260951364705563393093, 875260951364705563393092, 5], + ] + + +# Cyclic sub-groups of (Z/nZ)* with prime order p and generator g +# (n, p are primes and n = 2 * p + 1) +data_set_2 = [ + # n, p, g + [227, 113, 3], + [2447, 1223, 2], + [24527, 12263, 2], + [245639, 122819, 2], + [2456747, 1228373, 3], + [24567899, 12283949, 3], + [245679023, 122839511, 2], + [2456791307, 1228395653, 3], + [24567913439, 12283956719, 2], + [245679135407, 122839567703, 2], + [2456791354763, 1228395677381, 3], + [24567913550903, 12283956775451, 2], + [245679135509519, 122839567754759, 2], + ] + + +# Cyclic sub-groups of (Z/nZ)* with smooth order o and generator g +data_set_3 = [ + # n, o, g + [2**118, 2**116, 3], + ] + + +def bench_discrete_log(data_set, algo=None): + if algo is None: + f = discrete_log + elif algo == 'trial': + f = _discrete_log_trial_mul + elif algo == 'shanks': + f = _discrete_log_shanks_steps + elif algo == 'rho': + f = _discrete_log_pollard_rho + elif algo == 'ph': + f = _discrete_log_pohlig_hellman + else: + raise ValueError("Argument 'algo' should be one" + " of ('trial', 'shanks', 'rho' or 'ph')") + + for i, data in enumerate(data_set): + for j, (n, p, g) in enumerate(data): + t = time() + l = f(n, pow(g, p - 1, n), g, p) + t = time() - t + print('[%02d-%03d] %15.10f' % (i, j, t)) + assert l == p - 1 + + +if __name__ == '__main__': + algo = sys.argv[1] \ + if len(sys.argv) > 1 else None + data_set = [ + data_set_1, + data_set_2, + data_set_3, + ] + bench_discrete_log(data_set, algo) diff --git a/.venv/lib/python3.13/site-packages/sympy/benchmarks/bench_meijerint.py b/.venv/lib/python3.13/site-packages/sympy/benchmarks/bench_meijerint.py new file mode 100644 index 0000000000000000000000000000000000000000..d648c3e02463d5a7ee1dcbe3b22af5cc22fef43d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/benchmarks/bench_meijerint.py @@ -0,0 +1,261 @@ +# conceal the implicit import from the code quality tester +from sympy.core.numbers import (oo, pi) +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.bessel import besseli +from sympy.functions.special.gamma_functions import gamma +from sympy.integrals.integrals import integrate +from sympy.integrals.transforms import (mellin_transform, + inverse_fourier_transform, inverse_mellin_transform, + laplace_transform, inverse_laplace_transform, fourier_transform) + +LT = laplace_transform +FT = fourier_transform +MT = mellin_transform +IFT = inverse_fourier_transform +ILT = inverse_laplace_transform +IMT = inverse_mellin_transform + +from sympy.abc import x, y +nu, beta, rho = symbols('nu beta rho') + +apos, bpos, cpos, dpos, posk, p = symbols('a b c d k p', positive=True) +k = Symbol('k', real=True) +negk = Symbol('k', negative=True) + +mu1, mu2 = symbols('mu1 mu2', real=True, nonzero=True, finite=True) +sigma1, sigma2 = symbols('sigma1 sigma2', real=True, nonzero=True, + finite=True, positive=True) +rate = Symbol('lambda', positive=True) + + +def normal(x, mu, sigma): + return 1/sqrt(2*pi*sigma**2)*exp(-(x - mu)**2/2/sigma**2) + + +def exponential(x, rate): + return rate*exp(-rate*x) +alpha, beta = symbols('alpha beta', positive=True) +betadist = x**(alpha - 1)*(1 + x)**(-alpha - beta)*gamma(alpha + beta) \ + /gamma(alpha)/gamma(beta) +kint = Symbol('k', integer=True, positive=True) +chi = 2**(1 - kint/2)*x**(kint - 1)*exp(-x**2/2)/gamma(kint/2) +chisquared = 2**(-k/2)/gamma(k/2)*x**(k/2 - 1)*exp(-x/2) +dagum = apos*p/x*(x/bpos)**(apos*p)/(1 + x**apos/bpos**apos)**(p + 1) +d1, d2 = symbols('d1 d2', positive=True) +f = sqrt(((d1*x)**d1 * d2**d2)/(d1*x + d2)**(d1 + d2))/x \ + /gamma(d1/2)/gamma(d2/2)*gamma((d1 + d2)/2) +nupos, sigmapos = symbols('nu sigma', positive=True) +rice = x/sigmapos**2*exp(-(x**2 + nupos**2)/2/sigmapos**2)*besseli(0, x* + nupos/sigmapos**2) +mu = Symbol('mu', real=True) +laplace = exp(-abs(x - mu)/bpos)/2/bpos + +u = Symbol('u', polar=True) +tpos = Symbol('t', positive=True) + + +def E(expr): + integrate(expr*exponential(x, rate)*normal(y, mu1, sigma1), + (x, 0, oo), (y, -oo, oo), meijerg=True) + integrate(expr*exponential(x, rate)*normal(y, mu1, sigma1), + (y, -oo, oo), (x, 0, oo), meijerg=True) + +bench = [ + 'MT(x**nu*Heaviside(x - 1), x, s)', + 'MT(x**nu*Heaviside(1 - x), x, s)', + 'MT((1-x)**(beta - 1)*Heaviside(1-x), x, s)', + 'MT((x-1)**(beta - 1)*Heaviside(x-1), x, s)', + 'MT((1+x)**(-rho), x, s)', + 'MT(abs(1-x)**(-rho), x, s)', + 'MT((1-x)**(beta-1)*Heaviside(1-x) + a*(x-1)**(beta-1)*Heaviside(x-1), x, s)', + 'MT((x**a-b**a)/(x-b), x, s)', + 'MT((x**a-bpos**a)/(x-bpos), x, s)', + 'MT(exp(-x), x, s)', + 'MT(exp(-1/x), x, s)', + 'MT(log(x)**4*Heaviside(1-x), x, s)', + 'MT(log(x)**3*Heaviside(x-1), x, s)', + 'MT(log(x + 1), x, s)', + 'MT(log(1/x + 1), x, s)', + 'MT(log(abs(1 - x)), x, s)', + 'MT(log(abs(1 - 1/x)), x, s)', + 'MT(log(x)/(x+1), x, s)', + 'MT(log(x)**2/(x+1), x, s)', + 'MT(log(x)/(x+1)**2, x, s)', + 'MT(erf(sqrt(x)), x, s)', + + 'MT(besselj(a, 2*sqrt(x)), x, s)', + 'MT(sin(sqrt(x))*besselj(a, sqrt(x)), x, s)', + 'MT(cos(sqrt(x))*besselj(a, sqrt(x)), x, s)', + 'MT(besselj(a, sqrt(x))**2, x, s)', + 'MT(besselj(a, sqrt(x))*besselj(-a, sqrt(x)), x, s)', + 'MT(besselj(a - 1, sqrt(x))*besselj(a, sqrt(x)), x, s)', + 'MT(besselj(a, sqrt(x))*besselj(b, sqrt(x)), x, s)', + 'MT(besselj(a, sqrt(x))**2 + besselj(-a, sqrt(x))**2, x, s)', + 'MT(bessely(a, 2*sqrt(x)), x, s)', + 'MT(sin(sqrt(x))*bessely(a, sqrt(x)), x, s)', + 'MT(cos(sqrt(x))*bessely(a, sqrt(x)), x, s)', + 'MT(besselj(a, sqrt(x))*bessely(a, sqrt(x)), x, s)', + 'MT(besselj(a, sqrt(x))*bessely(b, sqrt(x)), x, s)', + 'MT(bessely(a, sqrt(x))**2, x, s)', + + 'MT(besselk(a, 2*sqrt(x)), x, s)', + 'MT(besselj(a, 2*sqrt(2*sqrt(x)))*besselk(a, 2*sqrt(2*sqrt(x))), x, s)', + 'MT(besseli(a, sqrt(x))*besselk(a, sqrt(x)), x, s)', + 'MT(besseli(b, sqrt(x))*besselk(a, sqrt(x)), x, s)', + 'MT(exp(-x/2)*besselk(a, x/2), x, s)', + + # later: ILT, IMT + + 'LT((t-apos)**bpos*exp(-cpos*(t-apos))*Heaviside(t-apos), t, s)', + 'LT(t**apos, t, s)', + 'LT(Heaviside(t), t, s)', + 'LT(Heaviside(t - apos), t, s)', + 'LT(1 - exp(-apos*t), t, s)', + 'LT((exp(2*t)-1)*exp(-bpos - t)*Heaviside(t)/2, t, s, noconds=True)', + 'LT(exp(t), t, s)', + 'LT(exp(2*t), t, s)', + 'LT(exp(apos*t), t, s)', + 'LT(log(t/apos), t, s)', + 'LT(erf(t), t, s)', + 'LT(sin(apos*t), t, s)', + 'LT(cos(apos*t), t, s)', + 'LT(exp(-apos*t)*sin(bpos*t), t, s)', + 'LT(exp(-apos*t)*cos(bpos*t), t, s)', + 'LT(besselj(0, t), t, s, noconds=True)', + 'LT(besselj(1, t), t, s, noconds=True)', + + 'FT(Heaviside(1 - abs(2*apos*x)), x, k)', + 'FT(Heaviside(1-abs(apos*x))*(1-abs(apos*x)), x, k)', + 'FT(exp(-apos*x)*Heaviside(x), x, k)', + 'IFT(1/(apos + 2*pi*I*x), x, posk, noconds=False)', + 'IFT(1/(apos + 2*pi*I*x), x, -posk, noconds=False)', + 'IFT(1/(apos + 2*pi*I*x), x, negk)', + 'FT(x*exp(-apos*x)*Heaviside(x), x, k)', + 'FT(exp(-apos*x)*sin(bpos*x)*Heaviside(x), x, k)', + 'FT(exp(-apos*x**2), x, k)', + 'IFT(sqrt(pi/apos)*exp(-(pi*k)**2/apos), k, x)', + 'FT(exp(-apos*abs(x)), x, k)', + + 'integrate(normal(x, mu1, sigma1), (x, -oo, oo), meijerg=True)', + 'integrate(x*normal(x, mu1, sigma1), (x, -oo, oo), meijerg=True)', + 'integrate(x**2*normal(x, mu1, sigma1), (x, -oo, oo), meijerg=True)', + 'integrate(x**3*normal(x, mu1, sigma1), (x, -oo, oo), meijerg=True)', + 'integrate(normal(x, mu1, sigma1)*normal(y, mu2, sigma2),' + ' (x, -oo, oo), (y, -oo, oo), meijerg=True)', + 'integrate(x*normal(x, mu1, sigma1)*normal(y, mu2, sigma2),' + ' (x, -oo, oo), (y, -oo, oo), meijerg=True)', + 'integrate(y*normal(x, mu1, sigma1)*normal(y, mu2, sigma2),' + ' (x, -oo, oo), (y, -oo, oo), meijerg=True)', + 'integrate(x*y*normal(x, mu1, sigma1)*normal(y, mu2, sigma2),' + ' (x, -oo, oo), (y, -oo, oo), meijerg=True)', + 'integrate((x+y+1)*normal(x, mu1, sigma1)*normal(y, mu2, sigma2),' + ' (x, -oo, oo), (y, -oo, oo), meijerg=True)', + 'integrate((x+y-1)*normal(x, mu1, sigma1)*normal(y, mu2, sigma2),' + ' (x, -oo, oo), (y, -oo, oo), meijerg=True)', + 'integrate(x**2*normal(x, mu1, sigma1)*normal(y, mu2, sigma2),' + ' (x, -oo, oo), (y, -oo, oo), meijerg=True)', + 'integrate(y**2*normal(x, mu1, sigma1)*normal(y, mu2, sigma2),' + ' (x, -oo, oo), (y, -oo, oo), meijerg=True)', + 'integrate(exponential(x, rate), (x, 0, oo), meijerg=True)', + 'integrate(x*exponential(x, rate), (x, 0, oo), meijerg=True)', + 'integrate(x**2*exponential(x, rate), (x, 0, oo), meijerg=True)', + 'E(1)', + 'E(x*y)', + 'E(x*y**2)', + 'E((x+y+1)**2)', + 'E(x+y+1)', + 'E((x+y-1)**2)', + 'integrate(betadist, (x, 0, oo), meijerg=True)', + 'integrate(x*betadist, (x, 0, oo), meijerg=True)', + 'integrate(x**2*betadist, (x, 0, oo), meijerg=True)', + 'integrate(chi, (x, 0, oo), meijerg=True)', + 'integrate(x*chi, (x, 0, oo), meijerg=True)', + 'integrate(x**2*chi, (x, 0, oo), meijerg=True)', + 'integrate(chisquared, (x, 0, oo), meijerg=True)', + 'integrate(x*chisquared, (x, 0, oo), meijerg=True)', + 'integrate(x**2*chisquared, (x, 0, oo), meijerg=True)', + 'integrate(((x-k)/sqrt(2*k))**3*chisquared, (x, 0, oo), meijerg=True)', + 'integrate(dagum, (x, 0, oo), meijerg=True)', + 'integrate(x*dagum, (x, 0, oo), meijerg=True)', + 'integrate(x**2*dagum, (x, 0, oo), meijerg=True)', + 'integrate(f, (x, 0, oo), meijerg=True)', + 'integrate(x*f, (x, 0, oo), meijerg=True)', + 'integrate(x**2*f, (x, 0, oo), meijerg=True)', + 'integrate(rice, (x, 0, oo), meijerg=True)', + 'integrate(laplace, (x, -oo, oo), meijerg=True)', + 'integrate(x*laplace, (x, -oo, oo), meijerg=True)', + 'integrate(x**2*laplace, (x, -oo, oo), meijerg=True)', + 'integrate(log(x) * x**(k-1) * exp(-x) / gamma(k), (x, 0, oo))', + + 'integrate(sin(z*x)*(x**2-1)**(-(y+S(1)/2)), (x, 1, oo), meijerg=True)', + 'integrate(besselj(0,x)*besselj(1,x)*exp(-x**2), (x, 0, oo), meijerg=True)', + 'integrate(besselj(0,x)*besselj(1,x)*besselk(0,x), (x, 0, oo), meijerg=True)', + 'integrate(besselj(0,x)*besselj(1,x)*exp(-x**2), (x, 0, oo), meijerg=True)', + 'integrate(besselj(a,x)*besselj(b,x)/x, (x,0,oo), meijerg=True)', + + 'hyperexpand(meijerg((-s - a/2 + 1, -s + a/2 + 1), (-a/2 - S(1)/2, -s + a/2 + S(3)/2), (a/2, -a/2), (-a/2 - S(1)/2, -s + a/2 + S(3)/2), 1))', + "gammasimp(S('2**(2*s)*(-pi*gamma(-a + 1)*gamma(a + 1)*gamma(-a - s + 1)*gamma(-a + s - 1/2)*gamma(a - s + 3/2)*gamma(a + s + 1)/(a*(a + s)) - gamma(-a - 1/2)*gamma(-a + 1)*gamma(a + 1)*gamma(a + 3/2)*gamma(-s + 3/2)*gamma(s - 1/2)*gamma(-a + s + 1)*gamma(a - s + 1)/(a*(-a + s)))*gamma(-2*s + 1)*gamma(s + 1)/(pi*s*gamma(-a - 1/2)*gamma(a + 3/2)*gamma(-s + 1)*gamma(-s + 3/2)*gamma(s - 1/2)*gamma(-a - s + 1)*gamma(-a + s - 1/2)*gamma(a - s + 1)*gamma(a - s + 3/2))'))", + + 'mellin_transform(E1(x), x, s)', + 'inverse_mellin_transform(gamma(s)/s, s, x, (0, oo))', + 'mellin_transform(expint(a, x), x, s)', + 'mellin_transform(Si(x), x, s)', + 'inverse_mellin_transform(-2**s*sqrt(pi)*gamma((s + 1)/2)/(2*s*gamma(-s/2 + 1)), s, x, (-1, 0))', + 'mellin_transform(Ci(sqrt(x)), x, s)', + 'inverse_mellin_transform(-4**s*sqrt(pi)*gamma(s)/(2*s*gamma(-s + S(1)/2)),s, u, (0, 1))', + 'laplace_transform(Ci(x), x, s)', + 'laplace_transform(expint(a, x), x, s)', + 'laplace_transform(expint(1, x), x, s)', + 'laplace_transform(expint(2, x), x, s)', + 'inverse_laplace_transform(-log(1 + s**2)/2/s, s, u)', + 'inverse_laplace_transform(log(s + 1)/s, s, x)', + 'inverse_laplace_transform((s - log(s + 1))/s**2, s, x)', + 'laplace_transform(Chi(x), x, s)', + 'laplace_transform(Shi(x), x, s)', + + 'integrate(exp(-z*x)/x, (x, 1, oo), meijerg=True, conds="none")', + 'integrate(exp(-z*x)/x**2, (x, 1, oo), meijerg=True, conds="none")', + 'integrate(exp(-z*x)/x**3, (x, 1, oo), meijerg=True,conds="none")', + 'integrate(-cos(x)/x, (x, tpos, oo), meijerg=True)', + 'integrate(-sin(x)/x, (x, tpos, oo), meijerg=True)', + 'integrate(sin(x)/x, (x, 0, z), meijerg=True)', + 'integrate(sinh(x)/x, (x, 0, z), meijerg=True)', + 'integrate(exp(-x)/x, x, meijerg=True)', + 'integrate(exp(-x)/x**2, x, meijerg=True)', + 'integrate(cos(u)/u, u, meijerg=True)', + 'integrate(cosh(u)/u, u, meijerg=True)', + 'integrate(expint(1, x), x, meijerg=True)', + 'integrate(expint(2, x), x, meijerg=True)', + 'integrate(Si(x), x, meijerg=True)', + 'integrate(Ci(u), u, meijerg=True)', + 'integrate(Shi(x), x, meijerg=True)', + 'integrate(Chi(u), u, meijerg=True)', + 'integrate(Si(x)*exp(-x), (x, 0, oo), meijerg=True)', + 'integrate(expint(1, x)*sin(x), (x, 0, oo), meijerg=True)' +] + +from time import time +from sympy.core.cache import clear_cache +import sys + +timings = [] + +if __name__ == '__main__': + for n, string in enumerate(bench): + clear_cache() + _t = time() + exec(string) + _t = time() - _t + timings += [(_t, string)] + sys.stdout.write('.') + sys.stdout.flush() + if n % (len(bench) // 10) == 0: + sys.stdout.write('%s' % (10*n // len(bench))) + print() + + timings.sort(key=lambda x: -x[0]) + + for ti, string in timings: + print('%.2fs %s' % (ti, string)) diff --git a/.venv/lib/python3.13/site-packages/sympy/benchmarks/bench_symbench.py b/.venv/lib/python3.13/site-packages/sympy/benchmarks/bench_symbench.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea700b44b677107f5345196a8895e8ed5a9d56d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/benchmarks/bench_symbench.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python +from sympy.core.random import random +from sympy.core.numbers import (I, Integer, pi) +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.polys.polytools import factor +from sympy.simplify.simplify import simplify +from sympy.abc import x, y, z +from timeit import default_timer as clock + + +def bench_R1(): + "real(f(f(f(f(f(f(f(f(f(f(i/2)))))))))))" + def f(z): + return sqrt(Integer(1)/3)*z**2 + I/3 + f(f(f(f(f(f(f(f(f(f(I/2)))))))))).as_real_imag()[0] + + +def bench_R2(): + "Hermite polynomial hermite(15, y)" + def hermite(n, y): + if n == 1: + return 2*y + if n == 0: + return 1 + return (2*y*hermite(n - 1, y) - 2*(n - 1)*hermite(n - 2, y)).expand() + + hermite(15, y) + + +def bench_R3(): + "a = [bool(f==f) for _ in range(10)]" + f = x + y + z + [bool(f == f) for _ in range(10)] + + +def bench_R4(): + # we don't have Tuples + pass + + +def bench_R5(): + "blowup(L, 8); L=uniq(L)" + def blowup(L, n): + for i in range(n): + L.append( (L[i] + L[i + 1]) * L[i + 2] ) + + def uniq(x): + v = set(x) + return v + L = [x, y, z] + blowup(L, 8) + L = uniq(L) + + +def bench_R6(): + "sum(simplify((x+sin(i))/x+(x-sin(i))/x) for i in range(100))" + sum(simplify((x + sin(i))/x + (x - sin(i))/x) for i in range(100)) + + +def bench_R7(): + "[f.subs(x, random()) for _ in range(10**4)]" + f = x**24 + 34*x**12 + 45*x**3 + 9*x**18 + 34*x**10 + 32*x**21 + [f.subs(x, random()) for _ in range(10**4)] + + +def bench_R8(): + "right(x^2,0,5,10^4)" + def right(f, a, b, n): + a = sympify(a) + b = sympify(b) + n = sympify(n) + x = f.atoms(Symbol).pop() + Deltax = (b - a)/n + c = a + est = 0 + for i in range(n): + c += Deltax + est += f.subs(x, c) + return est*Deltax + + right(x**2, 0, 5, 10**4) + + +def _bench_R9(): + "factor(x^20 - pi^5*y^20)" + factor(x**20 - pi**5*y**20) + + +def bench_R10(): + "v = [-pi,-pi+1/10..,pi]" + def srange(min, max, step): + v = [min] + while (max - v[-1]).evalf() > 0: + v.append(v[-1] + step) + return v[:-1] + srange(-pi, pi, sympify(1)/10) + + +def bench_R11(): + "a = [random() + random()*I for w in [0..1000]]" + [random() + random()*I for w in range(1000)] + + +def bench_S1(): + "e=(x+y+z+1)**7;f=e*(e+1);f.expand()" + e = (x + y + z + 1)**7 + f = e*(e + 1) + f.expand() + + +if __name__ == '__main__': + benchmarks = [ + bench_R1, + bench_R2, + bench_R3, + bench_R5, + bench_R6, + bench_R7, + bench_R8, + #_bench_R9, + bench_R10, + bench_R11, + #bench_S1, + ] + + report = [] + for b in benchmarks: + t = clock() + b() + t = clock() - t + print("%s%65s: %f" % (b.__name__, b.__doc__, t)) diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/__init__.py b/.venv/lib/python3.13/site-packages/sympy/codegen/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..62b195633bae28371bdf9e79317050d7fa7125ae --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/__init__.py @@ -0,0 +1,24 @@ +""" The ``sympy.codegen`` module contains classes and functions for building +abstract syntax trees of algorithms. These trees may then be printed by the +code-printers in ``sympy.printing``. + +There are several submodules available: +- ``sympy.codegen.ast``: AST nodes useful across multiple languages. +- ``sympy.codegen.cnodes``: AST nodes useful for the C family of languages. +- ``sympy.codegen.fnodes``: AST nodes useful for Fortran. +- ``sympy.codegen.cfunctions``: functions specific to C (C99 math functions) +- ``sympy.codegen.ffunctions``: functions specific to Fortran (e.g. ``kind``). + + + +""" +from .ast import ( + Assignment, aug_assign, CodeBlock, For, Attribute, Variable, Declaration, + While, Scope, Print, FunctionPrototype, FunctionDefinition, FunctionCall +) + +__all__ = [ + 'Assignment', 'aug_assign', 'CodeBlock', 'For', 'Attribute', 'Variable', + 'Declaration', 'While', 'Scope', 'Print', 'FunctionPrototype', + 'FunctionDefinition', 'FunctionCall', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/abstract_nodes.py b/.venv/lib/python3.13/site-packages/sympy/codegen/abstract_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..ae0a8b3e996a7112edf2568c00138e11c3f3327d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/abstract_nodes.py @@ -0,0 +1,18 @@ +"""This module provides containers for python objects that are valid +printing targets but are not a subclass of SymPy's Printable. +""" + + +from sympy.core.containers import Tuple + + +class List(Tuple): + """Represents a (frozen) (Python) list (for code printing purposes).""" + def __eq__(self, other): + if isinstance(other, list): + return self == List(*other) + else: + return self.args == other + + def __hash__(self): + return super().__hash__() diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/algorithms.py b/.venv/lib/python3.13/site-packages/sympy/codegen/algorithms.py new file mode 100644 index 0000000000000000000000000000000000000000..f4890eb8c25e565095600e2713c0de270aa0cf97 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/algorithms.py @@ -0,0 +1,180 @@ +from sympy.core.containers import Tuple +from sympy.core.numbers import oo +from sympy.core.relational import (Gt, Lt) +from sympy.core.symbol import (Dummy, Symbol) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.miscellaneous import Min, Max +from sympy.logic.boolalg import And +from sympy.codegen.ast import ( + Assignment, AddAugmentedAssignment, break_, CodeBlock, Declaration, FunctionDefinition, + Print, Return, Scope, While, Variable, Pointer, real +) +from sympy.codegen.cfunctions import isnan + +""" This module collects functions for constructing ASTs representing algorithms. """ + +def newtons_method(expr, wrt, atol=1e-12, delta=None, *, rtol=4e-16, debug=False, + itermax=None, counter=None, delta_fn=lambda e, x: -e/e.diff(x), + cse=False, handle_nan=None, + bounds=None): + """ Generates an AST for Newton-Raphson method (a root-finding algorithm). + + Explanation + =========== + + Returns an abstract syntax tree (AST) based on ``sympy.codegen.ast`` for Netwon's + method of root-finding. + + Parameters + ========== + + expr : expression + wrt : Symbol + With respect to, i.e. what is the variable. + atol : number or expression + Absolute tolerance (stopping criterion) + rtol : number or expression + Relative tolerance (stopping criterion) + delta : Symbol + Will be a ``Dummy`` if ``None``. + debug : bool + Whether to print convergence information during iterations + itermax : number or expr + Maximum number of iterations. + counter : Symbol + Will be a ``Dummy`` if ``None``. + delta_fn: Callable[[Expr, Symbol], Expr] + computes the step, default is newtons method. For e.g. Halley's method + use delta_fn=lambda e, x: -2*e*e.diff(x)/(2*e.diff(x)**2 - e*e.diff(x, 2)) + cse: bool + Perform common sub-expression elimination on delta expression + handle_nan: Token + How to handle occurrence of not-a-number (NaN). + bounds: Optional[tuple[Expr, Expr]] + Perform optimization within bounds + + Examples + ======== + + >>> from sympy import symbols, cos + >>> from sympy.codegen.ast import Assignment + >>> from sympy.codegen.algorithms import newtons_method + >>> x, dx, atol = symbols('x dx atol') + >>> expr = cos(x) - x**3 + >>> algo = newtons_method(expr, x, atol=atol, delta=dx) + >>> algo.has(Assignment(dx, -expr/expr.diff(x))) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Newton%27s_method + + """ + + if delta is None: + delta = Dummy() + Wrapper = Scope + name_d = 'delta' + else: + Wrapper = lambda x: x + name_d = delta.name + + delta_expr = delta_fn(expr, wrt) + if cse: + from sympy.simplify.cse_main import cse + cses, (red,) = cse([delta_expr.factor()]) + whl_bdy = [Assignment(dum, sub_e) for dum, sub_e in cses] + whl_bdy += [Assignment(delta, red)] + else: + whl_bdy = [Assignment(delta, delta_expr)] + if handle_nan is not None: + whl_bdy += [While(isnan(delta), CodeBlock(handle_nan, break_))] + whl_bdy += [AddAugmentedAssignment(wrt, delta)] + if bounds is not None: + whl_bdy += [Assignment(wrt, Min(Max(wrt, bounds[0]), bounds[1]))] + if debug: + prnt = Print([wrt, delta], r"{}=%12.5g {}=%12.5g\n".format(wrt.name, name_d)) + whl_bdy += [prnt] + req = Gt(Abs(delta), atol + rtol*Abs(wrt)) + declars = [Declaration(Variable(delta, type=real, value=oo))] + if itermax is not None: + counter = counter or Dummy(integer=True) + v_counter = Variable.deduced(counter, 0) + declars.append(Declaration(v_counter)) + whl_bdy.append(AddAugmentedAssignment(counter, 1)) + req = And(req, Lt(counter, itermax)) + whl = While(req, CodeBlock(*whl_bdy)) + blck = declars + if debug: + blck.append(Print([wrt], r"{}=%12.5g\n".format(wrt.name))) + blck += [whl] + return Wrapper(CodeBlock(*blck)) + + +def _symbol_of(arg): + if isinstance(arg, Declaration): + arg = arg.variable.symbol + elif isinstance(arg, Variable): + arg = arg.symbol + return arg + + +def newtons_method_function(expr, wrt, params=None, func_name="newton", attrs=Tuple(), *, delta=None, **kwargs): + """ Generates an AST for a function implementing the Newton-Raphson method. + + Parameters + ========== + + expr : expression + wrt : Symbol + With respect to, i.e. what is the variable + params : iterable of symbols + Symbols appearing in expr that are taken as constants during the iterations + (these will be accepted as parameters to the generated function). + func_name : str + Name of the generated function. + attrs : Tuple + Attribute instances passed as ``attrs`` to ``FunctionDefinition``. + \\*\\*kwargs : + Keyword arguments passed to :func:`sympy.codegen.algorithms.newtons_method`. + + Examples + ======== + + >>> from sympy import symbols, cos + >>> from sympy.codegen.algorithms import newtons_method_function + >>> from sympy.codegen.pyutils import render_as_module + >>> x = symbols('x') + >>> expr = cos(x) - x**3 + >>> func = newtons_method_function(expr, x) + >>> py_mod = render_as_module(func) # source code as string + >>> namespace = {} + >>> exec(py_mod, namespace, namespace) + >>> res = eval('newton(0.5)', namespace) + >>> abs(res - 0.865474033102) < 1e-12 + True + + See Also + ======== + + sympy.codegen.algorithms.newtons_method + + """ + if params is None: + params = (wrt,) + pointer_subs = {p.symbol: Symbol('(*%s)' % p.symbol.name) + for p in params if isinstance(p, Pointer)} + if delta is None: + delta = Symbol('d_' + wrt.name) + if expr.has(delta): + delta = None # will use Dummy + algo = newtons_method(expr, wrt, delta=delta, **kwargs).xreplace(pointer_subs) + if isinstance(algo, Scope): + algo = algo.body + not_in_params = expr.free_symbols.difference({_symbol_of(p) for p in params}) + if not_in_params: + raise ValueError("Missing symbols in params: %s" % ', '.join(map(str, not_in_params))) + declars = tuple(Variable(p, real) for p in params) + body = CodeBlock(algo, Return(wrt)) + return FunctionDefinition(real, func_name, declars, body, attrs=attrs) diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/approximations.py b/.venv/lib/python3.13/site-packages/sympy/codegen/approximations.py new file mode 100644 index 0000000000000000000000000000000000000000..c6486926938224c5052dc5adfac29807e82eb4d8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/approximations.py @@ -0,0 +1,187 @@ +import math +from sympy.sets.sets import Interval +from sympy.calculus.singularities import is_increasing, is_decreasing +from sympy.codegen.rewriting import Optimization +from sympy.core.function import UndefinedFunction + +""" +This module collects classes useful for approximate rewriting of expressions. +This can be beneficial when generating numeric code for which performance is +of greater importance than precision (e.g. for preconditioners used in iterative +methods). +""" + +class SumApprox(Optimization): + """ + Approximates sum by neglecting small terms. + + Explanation + =========== + + If terms are expressions which can be determined to be monotonic, then + bounds for those expressions are added. + + Parameters + ========== + + bounds : dict + Mapping expressions to length 2 tuple of bounds (low, high). + reltol : number + Threshold for when to ignore a term. Taken relative to the largest + lower bound among bounds. + + Examples + ======== + + >>> from sympy import exp + >>> from sympy.abc import x, y, z + >>> from sympy.codegen.rewriting import optimize + >>> from sympy.codegen.approximations import SumApprox + >>> bounds = {x: (-1, 1), y: (1000, 2000), z: (-10, 3)} + >>> sum_approx3 = SumApprox(bounds, reltol=1e-3) + >>> sum_approx2 = SumApprox(bounds, reltol=1e-2) + >>> sum_approx1 = SumApprox(bounds, reltol=1e-1) + >>> expr = 3*(x + y + exp(z)) + >>> optimize(expr, [sum_approx3]) + 3*(x + y + exp(z)) + >>> optimize(expr, [sum_approx2]) + 3*y + 3*exp(z) + >>> optimize(expr, [sum_approx1]) + 3*y + + """ + + def __init__(self, bounds, reltol, **kwargs): + super().__init__(**kwargs) + self.bounds = bounds + self.reltol = reltol + + def __call__(self, expr): + return expr.factor().replace(self.query, lambda arg: self.value(arg)) + + def query(self, expr): + return expr.is_Add + + def value(self, add): + for term in add.args: + if term.is_number or term in self.bounds or len(term.free_symbols) != 1: + continue + fs, = term.free_symbols + if fs not in self.bounds: + continue + intrvl = Interval(*self.bounds[fs]) + if is_increasing(term, intrvl, fs): + self.bounds[term] = ( + term.subs({fs: self.bounds[fs][0]}), + term.subs({fs: self.bounds[fs][1]}) + ) + elif is_decreasing(term, intrvl, fs): + self.bounds[term] = ( + term.subs({fs: self.bounds[fs][1]}), + term.subs({fs: self.bounds[fs][0]}) + ) + else: + return add + + if all(term.is_number or term in self.bounds for term in add.args): + bounds = [(term, term) if term.is_number else self.bounds[term] for term in add.args] + largest_abs_guarantee = 0 + for lo, hi in bounds: + if lo <= 0 <= hi: + continue + largest_abs_guarantee = max(largest_abs_guarantee, + min(abs(lo), abs(hi))) + new_terms = [] + for term, (lo, hi) in zip(add.args, bounds): + if max(abs(lo), abs(hi)) >= largest_abs_guarantee*self.reltol: + new_terms.append(term) + return add.func(*new_terms) + else: + return add + + +class SeriesApprox(Optimization): + """ Approximates functions by expanding them as a series. + + Parameters + ========== + + bounds : dict + Mapping expressions to length 2 tuple of bounds (low, high). + reltol : number + Threshold for when to ignore a term. Taken relative to the largest + lower bound among bounds. + max_order : int + Largest order to include in series expansion + n_point_checks : int (even) + The validity of an expansion (with respect to reltol) is checked at + discrete points (linearly spaced over the bounds of the variable). The + number of points used in this numerical check is given by this number. + + Examples + ======== + + >>> from sympy import sin, pi + >>> from sympy.abc import x, y + >>> from sympy.codegen.rewriting import optimize + >>> from sympy.codegen.approximations import SeriesApprox + >>> bounds = {x: (-.1, .1), y: (pi-1, pi+1)} + >>> series_approx2 = SeriesApprox(bounds, reltol=1e-2) + >>> series_approx3 = SeriesApprox(bounds, reltol=1e-3) + >>> series_approx8 = SeriesApprox(bounds, reltol=1e-8) + >>> expr = sin(x)*sin(y) + >>> optimize(expr, [series_approx2]) + x*(-y + (y - pi)**3/6 + pi) + >>> optimize(expr, [series_approx3]) + (-x**3/6 + x)*sin(y) + >>> optimize(expr, [series_approx8]) + sin(x)*sin(y) + + """ + def __init__(self, bounds, reltol, max_order=4, n_point_checks=4, **kwargs): + super().__init__(**kwargs) + self.bounds = bounds + self.reltol = reltol + self.max_order = max_order + if n_point_checks % 2 == 1: + raise ValueError("Checking the solution at expansion point is not helpful") + self.n_point_checks = n_point_checks + self._prec = math.ceil(-math.log10(self.reltol)) + + def __call__(self, expr): + return expr.factor().replace(self.query, lambda arg: self.value(arg)) + + def query(self, expr): + return (expr.is_Function and not isinstance(expr, UndefinedFunction) + and len(expr.args) == 1) + + def value(self, fexpr): + free_symbols = fexpr.free_symbols + if len(free_symbols) != 1: + return fexpr + symb, = free_symbols + if symb not in self.bounds: + return fexpr + lo, hi = self.bounds[symb] + x0 = (lo + hi)/2 + cheapest = None + for n in range(self.max_order+1, 0, -1): + fseri = fexpr.series(symb, x0=x0, n=n).removeO() + n_ok = True + for idx in range(self.n_point_checks): + x = lo + idx*(hi - lo)/(self.n_point_checks - 1) + val = fseri.xreplace({symb: x}) + ref = fexpr.xreplace({symb: x}) + if abs((1 - val/ref).evalf(self._prec)) > self.reltol: + n_ok = False + break + + if n_ok: + cheapest = fseri + else: + break + + if cheapest is None: + return fexpr + else: + return cheapest diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/ast.py b/.venv/lib/python3.13/site-packages/sympy/codegen/ast.py new file mode 100644 index 0000000000000000000000000000000000000000..dd774ca87c5c9d4b55c8ea7a3b68837035b0d06d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/ast.py @@ -0,0 +1,1906 @@ +""" +Types used to represent a full function/module as an Abstract Syntax Tree. + +Most types are small, and are merely used as tokens in the AST. A tree diagram +has been included below to illustrate the relationships between the AST types. + + +AST Type Tree +------------- +:: + + *Basic* + | + | + CodegenAST + | + |--->AssignmentBase + | |--->Assignment + | |--->AugmentedAssignment + | |--->AddAugmentedAssignment + | |--->SubAugmentedAssignment + | |--->MulAugmentedAssignment + | |--->DivAugmentedAssignment + | |--->ModAugmentedAssignment + | + |--->CodeBlock + | + | + |--->Token + |--->Attribute + |--->For + |--->String + | |--->QuotedString + | |--->Comment + |--->Type + | |--->IntBaseType + | | |--->_SizedIntType + | | |--->SignedIntType + | | |--->UnsignedIntType + | |--->FloatBaseType + | |--->FloatType + | |--->ComplexBaseType + | |--->ComplexType + |--->Node + | |--->Variable + | | |---> Pointer + | |--->FunctionPrototype + | |--->FunctionDefinition + |--->Element + |--->Declaration + |--->While + |--->Scope + |--->Stream + |--->Print + |--->FunctionCall + |--->BreakToken + |--->ContinueToken + |--->NoneToken + |--->Return + + +Predefined types +---------------- + +A number of ``Type`` instances are provided in the ``sympy.codegen.ast`` module +for convenience. Perhaps the two most common ones for code-generation (of numeric +codes) are ``float32`` and ``float64`` (known as single and double precision respectively). +There are also precision generic versions of Types (for which the codeprinters selects the +underlying data type at time of printing): ``real``, ``integer``, ``complex_``, ``bool_``. + +The other ``Type`` instances defined are: + +- ``intc``: Integer type used by C's "int". +- ``intp``: Integer type used by C's "unsigned". +- ``int8``, ``int16``, ``int32``, ``int64``: n-bit integers. +- ``uint8``, ``uint16``, ``uint32``, ``uint64``: n-bit unsigned integers. +- ``float80``: known as "extended precision" on modern x86/amd64 hardware. +- ``complex64``: Complex number represented by two ``float32`` numbers +- ``complex128``: Complex number represented by two ``float64`` numbers + +Using the nodes +--------------- + +It is possible to construct simple algorithms using the AST nodes. Let's construct a loop applying +Newton's method:: + + >>> from sympy import symbols, cos + >>> from sympy.codegen.ast import While, Assignment, aug_assign, Print, QuotedString + >>> t, dx, x = symbols('tol delta val') + >>> expr = cos(x) - x**3 + >>> whl = While(abs(dx) > t, [ + ... Assignment(dx, -expr/expr.diff(x)), + ... aug_assign(x, '+', dx), + ... Print([x]) + ... ]) + >>> from sympy import pycode + >>> py_str = pycode(whl) + >>> print(py_str) + while (abs(delta) > tol): + delta = (val**3 - math.cos(val))/(-3*val**2 - math.sin(val)) + val += delta + print(val) + >>> import math + >>> tol, val, delta = 1e-5, 0.5, float('inf') + >>> exec(py_str) + 1.1121416371 + 0.909672693737 + 0.867263818209 + 0.865477135298 + 0.865474033111 + >>> print('%3.1g' % (math.cos(val) - val**3)) + -3e-11 + +If we want to generate Fortran code for the same while loop we simple call ``fcode``:: + + >>> from sympy import fcode + >>> print(fcode(whl, standard=2003, source_format='free')) + do while (abs(delta) > tol) + delta = (val**3 - cos(val))/(-3*val**2 - sin(val)) + val = val + delta + print *, val + end do + +There is a function constructing a loop (or a complete function) like this in +:mod:`sympy.codegen.algorithms`. + +""" + +from __future__ import annotations +from typing import Any + +from collections import defaultdict + +from sympy.core.relational import (Ge, Gt, Le, Lt) +from sympy.core import Symbol, Tuple, Dummy +from sympy.core.basic import Basic +from sympy.core.expr import Expr, Atom +from sympy.core.numbers import Float, Integer, oo +from sympy.core.sympify import _sympify, sympify, SympifyError +from sympy.utilities.iterables import (iterable, topological_sort, + numbered_symbols, filter_symbols) + + +def _mk_Tuple(args): + """ + Create a SymPy Tuple object from an iterable, converting Python strings to + AST strings. + + Parameters + ========== + + args: iterable + Arguments to :class:`sympy.Tuple`. + + Returns + ======= + + sympy.Tuple + """ + args = [String(arg) if isinstance(arg, str) else arg for arg in args] + return Tuple(*args) + + +class CodegenAST(Basic): + __slots__ = () + + +class Token(CodegenAST): + """ Base class for the AST types. + + Explanation + =========== + + Defining fields are set in ``_fields``. Attributes (defined in _fields) + are only allowed to contain instances of Basic (unless atomic, see + ``String``). The arguments to ``__new__()`` correspond to the attributes in + the order defined in ``_fields`. The ``defaults`` class attribute is a + dictionary mapping attribute names to their default values. + + Subclasses should not need to override the ``__new__()`` method. They may + define a class or static method named ``_construct_`` for each + attribute to process the value passed to ``__new__()``. Attributes listed + in the class attribute ``not_in_args`` are not passed to :class:`~.Basic`. + """ + + __slots__: tuple[str, ...] = () + _fields = __slots__ + defaults: dict[str, Any] = {} + not_in_args: list[str] = [] + indented_args = ['body'] + + @property + def is_Atom(self): + return len(self._fields) == 0 + + @classmethod + def _get_constructor(cls, attr): + """ Get the constructor function for an attribute by name. """ + return getattr(cls, '_construct_%s' % attr, lambda x: x) + + @classmethod + def _construct(cls, attr, arg): + """ Construct an attribute value from argument passed to ``__new__()``. """ + # arg may be ``NoneToken()``, so comparison is done using == instead of ``is`` operator + if arg == None: + return cls.defaults.get(attr, none) + else: + if isinstance(arg, Dummy): # SymPy's replace uses Dummy instances + return arg + else: + return cls._get_constructor(attr)(arg) + + def __new__(cls, *args, **kwargs): + # Pass through existing instances when given as sole argument + if len(args) == 1 and not kwargs and isinstance(args[0], cls): + return args[0] + + if len(args) > len(cls._fields): + raise ValueError("Too many arguments (%d), expected at most %d" % (len(args), len(cls._fields))) + + attrvals = [] + + # Process positional arguments + for attrname, argval in zip(cls._fields, args): + if attrname in kwargs: + raise TypeError('Got multiple values for attribute %r' % attrname) + + attrvals.append(cls._construct(attrname, argval)) + + # Process keyword arguments + for attrname in cls._fields[len(args):]: + if attrname in kwargs: + argval = kwargs.pop(attrname) + + elif attrname in cls.defaults: + argval = cls.defaults[attrname] + + else: + raise TypeError('No value for %r given and attribute has no default' % attrname) + + attrvals.append(cls._construct(attrname, argval)) + + if kwargs: + raise ValueError("Unknown keyword arguments: %s" % ' '.join(kwargs)) + + # Parent constructor + basic_args = [ + val for attr, val in zip(cls._fields, attrvals) + if attr not in cls.not_in_args + ] + obj = CodegenAST.__new__(cls, *basic_args) + + # Set attributes + for attr, arg in zip(cls._fields, attrvals): + setattr(obj, attr, arg) + + return obj + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + for attr in self._fields: + if getattr(self, attr) != getattr(other, attr): + return False + return True + + def _hashable_content(self): + return tuple([getattr(self, attr) for attr in self._fields]) + + def __hash__(self): + return super().__hash__() + + def _joiner(self, k, indent_level): + return (',\n' + ' '*indent_level) if k in self.indented_args else ', ' + + def _indented(self, printer, k, v, *args, **kwargs): + il = printer._context['indent_level'] + def _print(arg): + if isinstance(arg, Token): + return printer._print(arg, *args, joiner=self._joiner(k, il), **kwargs) + else: + return printer._print(arg, *args, **kwargs) + + if isinstance(v, Tuple): + joined = self._joiner(k, il).join([_print(arg) for arg in v.args]) + if k in self.indented_args: + return '(\n' + ' '*il + joined + ',\n' + ' '*(il - 4) + ')' + else: + return ('({0},)' if len(v.args) == 1 else '({0})').format(joined) + else: + return _print(v) + + def _sympyrepr(self, printer, *args, joiner=', ', **kwargs): + from sympy.printing.printer import printer_context + exclude = kwargs.get('exclude', ()) + values = [getattr(self, k) for k in self._fields] + indent_level = printer._context.get('indent_level', 0) + + arg_reprs = [] + + for i, (attr, value) in enumerate(zip(self._fields, values)): + if attr in exclude: + continue + + # Skip attributes which have the default value + if attr in self.defaults and value == self.defaults[attr]: + continue + + ilvl = indent_level + 4 if attr in self.indented_args else 0 + with printer_context(printer, indent_level=ilvl): + indented = self._indented(printer, attr, value, *args, **kwargs) + arg_reprs.append(('{1}' if i == 0 else '{0}={1}').format(attr, indented.lstrip())) + + return "{}({})".format(self.__class__.__name__, joiner.join(arg_reprs)) + + _sympystr = _sympyrepr + + def __repr__(self): # sympy.core.Basic.__repr__ uses sstr + from sympy.printing import srepr + return srepr(self) + + def kwargs(self, exclude=(), apply=None): + """ Get instance's attributes as dict of keyword arguments. + + Parameters + ========== + + exclude : collection of str + Collection of keywords to exclude. + + apply : callable, optional + Function to apply to all values. + """ + kwargs = {k: getattr(self, k) for k in self._fields if k not in exclude} + if apply is not None: + return {k: apply(v) for k, v in kwargs.items()} + else: + return kwargs + +class BreakToken(Token): + """ Represents 'break' in C/Python ('exit' in Fortran). + + Use the premade instance ``break_`` or instantiate manually. + + Examples + ======== + + >>> from sympy import ccode, fcode + >>> from sympy.codegen.ast import break_ + >>> ccode(break_) + 'break' + >>> fcode(break_, source_format='free') + 'exit' + """ + +break_ = BreakToken() + + +class ContinueToken(Token): + """ Represents 'continue' in C/Python ('cycle' in Fortran) + + Use the premade instance ``continue_`` or instantiate manually. + + Examples + ======== + + >>> from sympy import ccode, fcode + >>> from sympy.codegen.ast import continue_ + >>> ccode(continue_) + 'continue' + >>> fcode(continue_, source_format='free') + 'cycle' + """ + +continue_ = ContinueToken() + +class NoneToken(Token): + """ The AST equivalence of Python's NoneType + + The corresponding instance of Python's ``None`` is ``none``. + + Examples + ======== + + >>> from sympy.codegen.ast import none, Variable + >>> from sympy import pycode + >>> print(pycode(Variable('x').as_Declaration(value=none))) + x = None + + """ + def __eq__(self, other): + return other is None or isinstance(other, NoneToken) + + def _hashable_content(self): + return () + + def __hash__(self): + return super().__hash__() + + +none = NoneToken() + + +class AssignmentBase(CodegenAST): + """ Abstract base class for Assignment and AugmentedAssignment. + + Attributes: + =========== + + op : str + Symbol for assignment operator, e.g. "=", "+=", etc. + """ + + def __new__(cls, lhs, rhs): + lhs = _sympify(lhs) + rhs = _sympify(rhs) + + cls._check_args(lhs, rhs) + + return super().__new__(cls, lhs, rhs) + + @property + def lhs(self): + return self.args[0] + + @property + def rhs(self): + return self.args[1] + + @classmethod + def _check_args(cls, lhs, rhs): + """ Check arguments to __new__ and raise exception if any problems found. + + Derived classes may wish to override this. + """ + from sympy.matrices.expressions.matexpr import ( + MatrixElement, MatrixSymbol) + from sympy.tensor.indexed import Indexed + from sympy.tensor.array.expressions import ArrayElement + + # Tuple of things that can be on the lhs of an assignment + assignable = (Symbol, MatrixSymbol, MatrixElement, Indexed, Element, Variable, + ArrayElement) + if not isinstance(lhs, assignable): + raise TypeError("Cannot assign to lhs of type %s." % type(lhs)) + + # Indexed types implement shape, but don't define it until later. This + # causes issues in assignment validation. For now, matrices are defined + # as anything with a shape that is not an Indexed + lhs_is_mat = hasattr(lhs, 'shape') and not isinstance(lhs, Indexed) + rhs_is_mat = hasattr(rhs, 'shape') and not isinstance(rhs, Indexed) + + # If lhs and rhs have same structure, then this assignment is ok + if lhs_is_mat: + if not rhs_is_mat: + raise ValueError("Cannot assign a scalar to a matrix.") + elif lhs.shape != rhs.shape: + raise ValueError("Dimensions of lhs and rhs do not align.") + elif rhs_is_mat and not lhs_is_mat: + raise ValueError("Cannot assign a matrix to a scalar.") + + +class Assignment(AssignmentBase): + """ + Represents variable assignment for code generation. + + Parameters + ========== + + lhs : Expr + SymPy object representing the lhs of the expression. These should be + singular objects, such as one would use in writing code. Notable types + include Symbol, MatrixSymbol, MatrixElement, and Indexed. Types that + subclass these types are also supported. + + rhs : Expr + SymPy object representing the rhs of the expression. This can be any + type, provided its shape corresponds to that of the lhs. For example, + a Matrix type can be assigned to MatrixSymbol, but not to Symbol, as + the dimensions will not align. + + Examples + ======== + + >>> from sympy import symbols, MatrixSymbol, Matrix + >>> from sympy.codegen.ast import Assignment + >>> x, y, z = symbols('x, y, z') + >>> Assignment(x, y) + Assignment(x, y) + >>> Assignment(x, 0) + Assignment(x, 0) + >>> A = MatrixSymbol('A', 1, 3) + >>> mat = Matrix([x, y, z]).T + >>> Assignment(A, mat) + Assignment(A, Matrix([[x, y, z]])) + >>> Assignment(A[0, 1], x) + Assignment(A[0, 1], x) + """ + + op = ':=' + + +class AugmentedAssignment(AssignmentBase): + """ + Base class for augmented assignments. + + Attributes: + =========== + + binop : str + Symbol for binary operation being applied in the assignment, such as "+", + "*", etc. + """ + binop: str | None + + @property + def op(self): + return self.binop + '=' + + +class AddAugmentedAssignment(AugmentedAssignment): + binop = '+' + + +class SubAugmentedAssignment(AugmentedAssignment): + binop = '-' + + +class MulAugmentedAssignment(AugmentedAssignment): + binop = '*' + + +class DivAugmentedAssignment(AugmentedAssignment): + binop = '/' + + +class ModAugmentedAssignment(AugmentedAssignment): + binop = '%' + + +# Mapping from binary op strings to AugmentedAssignment subclasses +augassign_classes = { + cls.binop: cls for cls in [ + AddAugmentedAssignment, SubAugmentedAssignment, MulAugmentedAssignment, + DivAugmentedAssignment, ModAugmentedAssignment + ] +} + + +def aug_assign(lhs, op, rhs): + """ + Create 'lhs op= rhs'. + + Explanation + =========== + + Represents augmented variable assignment for code generation. This is a + convenience function. You can also use the AugmentedAssignment classes + directly, like AddAugmentedAssignment(x, y). + + Parameters + ========== + + lhs : Expr + SymPy object representing the lhs of the expression. These should be + singular objects, such as one would use in writing code. Notable types + include Symbol, MatrixSymbol, MatrixElement, and Indexed. Types that + subclass these types are also supported. + + op : str + Operator (+, -, /, \\*, %). + + rhs : Expr + SymPy object representing the rhs of the expression. This can be any + type, provided its shape corresponds to that of the lhs. For example, + a Matrix type can be assigned to MatrixSymbol, but not to Symbol, as + the dimensions will not align. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.codegen.ast import aug_assign + >>> x, y = symbols('x, y') + >>> aug_assign(x, '+', y) + AddAugmentedAssignment(x, y) + """ + if op not in augassign_classes: + raise ValueError("Unrecognized operator %s" % op) + return augassign_classes[op](lhs, rhs) + + +class CodeBlock(CodegenAST): + """ + Represents a block of code. + + Explanation + =========== + + For now only assignments are supported. This restriction will be lifted in + the future. + + Useful attributes on this object are: + + ``left_hand_sides``: + Tuple of left-hand sides of assignments, in order. + ``left_hand_sides``: + Tuple of right-hand sides of assignments, in order. + ``free_symbols``: Free symbols of the expressions in the right-hand sides + which do not appear in the left-hand side of an assignment. + + Useful methods on this object are: + + ``topological_sort``: + Class method. Return a CodeBlock with assignments + sorted so that variables are assigned before they + are used. + ``cse``: + Return a new CodeBlock with common subexpressions eliminated and + pulled out as assignments. + + Examples + ======== + + >>> from sympy import symbols, ccode + >>> from sympy.codegen.ast import CodeBlock, Assignment + >>> x, y = symbols('x y') + >>> c = CodeBlock(Assignment(x, 1), Assignment(y, x + 1)) + >>> print(ccode(c)) + x = 1; + y = x + 1; + + """ + def __new__(cls, *args): + left_hand_sides = [] + right_hand_sides = [] + for i in args: + if isinstance(i, Assignment): + lhs, rhs = i.args + left_hand_sides.append(lhs) + right_hand_sides.append(rhs) + + obj = CodegenAST.__new__(cls, *args) + + obj.left_hand_sides = Tuple(*left_hand_sides) + obj.right_hand_sides = Tuple(*right_hand_sides) + return obj + + def __iter__(self): + return iter(self.args) + + def _sympyrepr(self, printer, *args, **kwargs): + il = printer._context.get('indent_level', 0) + joiner = ',\n' + ' '*il + joined = joiner.join(map(printer._print, self.args)) + return ('{}(\n'.format(' '*(il-4) + self.__class__.__name__,) + + ' '*il + joined + '\n' + ' '*(il - 4) + ')') + + _sympystr = _sympyrepr + + @property + def free_symbols(self): + return super().free_symbols - set(self.left_hand_sides) + + @classmethod + def topological_sort(cls, assignments): + """ + Return a CodeBlock with topologically sorted assignments so that + variables are assigned before they are used. + + Examples + ======== + + The existing order of assignments is preserved as much as possible. + + This function assumes that variables are assigned to only once. + + This is a class constructor so that the default constructor for + CodeBlock can error when variables are used before they are assigned. + + >>> from sympy import symbols + >>> from sympy.codegen.ast import CodeBlock, Assignment + >>> x, y, z = symbols('x y z') + + >>> assignments = [ + ... Assignment(x, y + z), + ... Assignment(y, z + 1), + ... Assignment(z, 2), + ... ] + >>> CodeBlock.topological_sort(assignments) + CodeBlock( + Assignment(z, 2), + Assignment(y, z + 1), + Assignment(x, y + z) + ) + + """ + + if not all(isinstance(i, Assignment) for i in assignments): + # Will support more things later + raise NotImplementedError("CodeBlock.topological_sort only supports Assignments") + + if any(isinstance(i, AugmentedAssignment) for i in assignments): + raise NotImplementedError("CodeBlock.topological_sort does not yet work with AugmentedAssignments") + + # Create a graph where the nodes are assignments and there is a directed edge + # between nodes that use a variable and nodes that assign that + # variable, like + + # [(x := 1, y := x + 1), (x := 1, z := y + z), (y := x + 1, z := y + z)] + + # If we then topologically sort these nodes, they will be in + # assignment order, like + + # x := 1 + # y := x + 1 + # z := y + z + + # A = The nodes + # + # enumerate keeps nodes in the same order they are already in if + # possible. It will also allow us to handle duplicate assignments to + # the same variable when those are implemented. + A = list(enumerate(assignments)) + + # var_map = {variable: [nodes for which this variable is assigned to]} + # like {x: [(1, x := y + z), (4, x := 2 * w)], ...} + var_map = defaultdict(list) + for node in A: + i, a = node + var_map[a.lhs].append(node) + + # E = Edges in the graph + E = [] + for dst_node in A: + i, a = dst_node + for s in a.rhs.free_symbols: + for src_node in var_map[s]: + E.append((src_node, dst_node)) + + ordered_assignments = topological_sort([A, E]) + + # De-enumerate the result + return cls(*[a for i, a in ordered_assignments]) + + def cse(self, symbols=None, optimizations=None, postprocess=None, + order='canonical'): + """ + Return a new code block with common subexpressions eliminated. + + Explanation + =========== + + See the docstring of :func:`sympy.simplify.cse_main.cse` for more + information. + + Examples + ======== + + >>> from sympy import symbols, sin + >>> from sympy.codegen.ast import CodeBlock, Assignment + >>> x, y, z = symbols('x y z') + + >>> c = CodeBlock( + ... Assignment(x, 1), + ... Assignment(y, sin(x) + 1), + ... Assignment(z, sin(x) - 1), + ... ) + ... + >>> c.cse() + CodeBlock( + Assignment(x, 1), + Assignment(x0, sin(x)), + Assignment(y, x0 + 1), + Assignment(z, x0 - 1) + ) + + """ + from sympy.simplify.cse_main import cse + + # Check that the CodeBlock only contains assignments to unique variables + if not all(isinstance(i, Assignment) for i in self.args): + # Will support more things later + raise NotImplementedError("CodeBlock.cse only supports Assignments") + + if any(isinstance(i, AugmentedAssignment) for i in self.args): + raise NotImplementedError("CodeBlock.cse does not yet work with AugmentedAssignments") + + for i, lhs in enumerate(self.left_hand_sides): + if lhs in self.left_hand_sides[:i]: + raise NotImplementedError("Duplicate assignments to the same " + "variable are not yet supported (%s)" % lhs) + + # Ensure new symbols for subexpressions do not conflict with existing + existing_symbols = self.atoms(Symbol) + if symbols is None: + symbols = numbered_symbols() + symbols = filter_symbols(symbols, existing_symbols) + + replacements, reduced_exprs = cse(list(self.right_hand_sides), + symbols=symbols, optimizations=optimizations, postprocess=postprocess, + order=order) + + new_block = [Assignment(var, expr) for var, expr in + zip(self.left_hand_sides, reduced_exprs)] + new_assignments = [Assignment(var, expr) for var, expr in replacements] + return self.topological_sort(new_assignments + new_block) + + +class For(Token): + """Represents a 'for-loop' in the code. + + Expressions are of the form: + "for target in iter: + body..." + + Parameters + ========== + + target : symbol + iter : iterable + body : CodeBlock or iterable +! When passed an iterable it is used to instantiate a CodeBlock. + + Examples + ======== + + >>> from sympy import symbols, Range + >>> from sympy.codegen.ast import aug_assign, For + >>> x, i, j, k = symbols('x i j k') + >>> for_i = For(i, Range(10), [aug_assign(x, '+', i*j*k)]) + >>> for_i # doctest: -NORMALIZE_WHITESPACE + For(i, iterable=Range(0, 10, 1), body=CodeBlock( + AddAugmentedAssignment(x, i*j*k) + )) + >>> for_ji = For(j, Range(7), [for_i]) + >>> for_ji # doctest: -NORMALIZE_WHITESPACE + For(j, iterable=Range(0, 7, 1), body=CodeBlock( + For(i, iterable=Range(0, 10, 1), body=CodeBlock( + AddAugmentedAssignment(x, i*j*k) + )) + )) + >>> for_kji =For(k, Range(5), [for_ji]) + >>> for_kji # doctest: -NORMALIZE_WHITESPACE + For(k, iterable=Range(0, 5, 1), body=CodeBlock( + For(j, iterable=Range(0, 7, 1), body=CodeBlock( + For(i, iterable=Range(0, 10, 1), body=CodeBlock( + AddAugmentedAssignment(x, i*j*k) + )) + )) + )) + """ + __slots__ = _fields = ('target', 'iterable', 'body') + _construct_target = staticmethod(_sympify) + + @classmethod + def _construct_body(cls, itr): + if isinstance(itr, CodeBlock): + return itr + else: + return CodeBlock(*itr) + + @classmethod + def _construct_iterable(cls, itr): + if not iterable(itr): + raise TypeError("iterable must be an iterable") + if isinstance(itr, list): # _sympify errors on lists because they are mutable + itr = tuple(itr) + return _sympify(itr) + + +class String(Atom, Token): + """ SymPy object representing a string. + + Atomic object which is not an expression (as opposed to Symbol). + + Parameters + ========== + + text : str + + Examples + ======== + + >>> from sympy.codegen.ast import String + >>> f = String('foo') + >>> f + foo + >>> str(f) + 'foo' + >>> f.text + 'foo' + >>> print(repr(f)) + String('foo') + + """ + __slots__ = _fields = ('text',) + not_in_args = ['text'] + is_Atom = True + + @classmethod + def _construct_text(cls, text): + if not isinstance(text, str): + raise TypeError("Argument text is not a string type.") + return text + + def _sympystr(self, printer, *args, **kwargs): + return self.text + + def kwargs(self, exclude = (), apply = None): + return {} + + #to be removed when Atom is given a suitable func + @property + def func(self): + return lambda: self + + def _latex(self, printer): + from sympy.printing.latex import latex_escape + return r'\texttt{{"{}"}}'.format(latex_escape(self.text)) + +class QuotedString(String): + """ Represents a string which should be printed with quotes. """ + +class Comment(String): + """ Represents a comment. """ + +class Node(Token): + """ Subclass of Token, carrying the attribute 'attrs' (Tuple) + + Examples + ======== + + >>> from sympy.codegen.ast import Node, value_const, pointer_const + >>> n1 = Node([value_const]) + >>> n1.attr_params('value_const') # get the parameters of attribute (by name) + () + >>> from sympy.codegen.fnodes import dimension + >>> n2 = Node([value_const, dimension(5, 3)]) + >>> n2.attr_params(value_const) # get the parameters of attribute (by Attribute instance) + () + >>> n2.attr_params('dimension') # get the parameters of attribute (by name) + (5, 3) + >>> n2.attr_params(pointer_const) is None + True + + """ + + __slots__: tuple[str, ...] = ('attrs',) + _fields = __slots__ + + defaults: dict[str, Any] = {'attrs': Tuple()} + + _construct_attrs = staticmethod(_mk_Tuple) + + def attr_params(self, looking_for): + """ Returns the parameters of the Attribute with name ``looking_for`` in self.attrs """ + for attr in self.attrs: + if str(attr.name) == str(looking_for): + return attr.parameters + + +class Type(Token): + """ Represents a type. + + Explanation + =========== + + The naming is a super-set of NumPy naming. Type has a classmethod + ``from_expr`` which offer type deduction. It also has a method + ``cast_check`` which casts the argument to its type, possibly raising an + exception if rounding error is not within tolerances, or if the value is not + representable by the underlying data type (e.g. unsigned integers). + + Parameters + ========== + + name : str + Name of the type, e.g. ``object``, ``int16``, ``float16`` (where the latter two + would use the ``Type`` sub-classes ``IntType`` and ``FloatType`` respectively). + If a ``Type`` instance is given, the said instance is returned. + + Examples + ======== + + >>> from sympy.codegen.ast import Type + >>> t = Type.from_expr(42) + >>> t + integer + >>> print(repr(t)) + IntBaseType(String('integer')) + >>> from sympy.codegen.ast import uint8 + >>> uint8.cast_check(-1) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + ValueError: Minimum value for data type bigger than new value. + >>> from sympy.codegen.ast import float32 + >>> v6 = 0.123456 + >>> float32.cast_check(v6) + 0.123456 + >>> v10 = 12345.67894 + >>> float32.cast_check(v10) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + ValueError: Casting gives a significantly different value. + >>> boost_mp50 = Type('boost::multiprecision::cpp_dec_float_50') + >>> from sympy import cxxcode + >>> from sympy.codegen.ast import Declaration, Variable + >>> cxxcode(Declaration(Variable('x', type=boost_mp50))) + 'boost::multiprecision::cpp_dec_float_50 x' + + References + ========== + + .. [1] https://numpy.org/doc/stable/user/basics.types.html + + """ + __slots__: tuple[str, ...] = ('name',) + _fields = __slots__ + + _construct_name = String + + def _sympystr(self, printer, *args, **kwargs): + return str(self.name) + + @classmethod + def from_expr(cls, expr): + """ Deduces type from an expression or a ``Symbol``. + + Parameters + ========== + + expr : number or SymPy object + The type will be deduced from type or properties. + + Examples + ======== + + >>> from sympy.codegen.ast import Type, integer, complex_ + >>> Type.from_expr(2) == integer + True + >>> from sympy import Symbol + >>> Type.from_expr(Symbol('z', complex=True)) == complex_ + True + >>> Type.from_expr(sum) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + ValueError: Could not deduce type from expr. + + Raises + ====== + + ValueError when type deduction fails. + + """ + if isinstance(expr, (float, Float)): + return real + if isinstance(expr, (int, Integer)) or getattr(expr, 'is_integer', False): + return integer + if getattr(expr, 'is_real', False): + return real + if isinstance(expr, complex) or getattr(expr, 'is_complex', False): + return complex_ + if isinstance(expr, bool) or getattr(expr, 'is_Relational', False): + return bool_ + else: + raise ValueError("Could not deduce type from expr.") + + def _check(self, value): + pass + + def cast_check(self, value, rtol=None, atol=0, precision_targets=None): + """ Casts a value to the data type of the instance. + + Parameters + ========== + + value : number + rtol : floating point number + Relative tolerance. (will be deduced if not given). + atol : floating point number + Absolute tolerance (in addition to ``rtol``). + type_aliases : dict + Maps substitutions for Type, e.g. {integer: int64, real: float32} + + Examples + ======== + + >>> from sympy.codegen.ast import integer, float32, int8 + >>> integer.cast_check(3.0) == 3 + True + >>> float32.cast_check(1e-40) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + ValueError: Minimum value for data type bigger than new value. + >>> int8.cast_check(256) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + ValueError: Maximum value for data type smaller than new value. + >>> v10 = 12345.67894 + >>> float32.cast_check(v10) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + ValueError: Casting gives a significantly different value. + >>> from sympy.codegen.ast import float64 + >>> float64.cast_check(v10) + 12345.67894 + >>> from sympy import Float + >>> v18 = Float('0.123456789012345646') + >>> float64.cast_check(v18) + Traceback (most recent call last): + ... + ValueError: Casting gives a significantly different value. + >>> from sympy.codegen.ast import float80 + >>> float80.cast_check(v18) + 0.123456789012345649 + + """ + val = sympify(value) + + ten = Integer(10) + exp10 = getattr(self, 'decimal_dig', None) + + if rtol is None: + rtol = 1e-15 if exp10 is None else 2.0*ten**(-exp10) + + def tol(num): + return atol + rtol*abs(num) + + new_val = self.cast_nocheck(value) + self._check(new_val) + + delta = new_val - val + if abs(delta) > tol(val): # rounding, e.g. int(3.5) != 3.5 + raise ValueError("Casting gives a significantly different value.") + + return new_val + + def _latex(self, printer): + from sympy.printing.latex import latex_escape + type_name = latex_escape(self.__class__.__name__) + name = latex_escape(self.name.text) + return r"\text{{{}}}\left(\texttt{{{}}}\right)".format(type_name, name) + + +class IntBaseType(Type): + """ Integer base type, contains no size information. """ + __slots__ = () + cast_nocheck = lambda self, i: Integer(int(i)) + + +class _SizedIntType(IntBaseType): + __slots__ = ('nbits',) + _fields = Type._fields + __slots__ + + _construct_nbits = Integer + + def _check(self, value): + if value < self.min: + raise ValueError("Value is too small: %d < %d" % (value, self.min)) + if value > self.max: + raise ValueError("Value is too big: %d > %d" % (value, self.max)) + + +class SignedIntType(_SizedIntType): + """ Represents a signed integer type. """ + __slots__ = () + @property + def min(self): + return -2**(self.nbits-1) + + @property + def max(self): + return 2**(self.nbits-1) - 1 + + +class UnsignedIntType(_SizedIntType): + """ Represents an unsigned integer type. """ + __slots__ = () + @property + def min(self): + return 0 + + @property + def max(self): + return 2**self.nbits - 1 + +two = Integer(2) + +class FloatBaseType(Type): + """ Represents a floating point number type. """ + __slots__ = () + cast_nocheck = Float + +class FloatType(FloatBaseType): + """ Represents a floating point type with fixed bit width. + + Base 2 & one sign bit is assumed. + + Parameters + ========== + + name : str + Name of the type. + nbits : integer + Number of bits used (storage). + nmant : integer + Number of bits used to represent the mantissa. + nexp : integer + Number of bits used to represent the mantissa. + + Examples + ======== + + >>> from sympy import S + >>> from sympy.codegen.ast import FloatType + >>> half_precision = FloatType('f16', nbits=16, nmant=10, nexp=5) + >>> half_precision.max + 65504 + >>> half_precision.tiny == S(2)**-14 + True + >>> half_precision.eps == S(2)**-10 + True + >>> half_precision.dig == 3 + True + >>> half_precision.decimal_dig == 5 + True + >>> half_precision.cast_check(1.0) + 1.0 + >>> half_precision.cast_check(1e5) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + ValueError: Maximum value for data type smaller than new value. + """ + + __slots__ = ('nbits', 'nmant', 'nexp',) + _fields = Type._fields + __slots__ + + _construct_nbits = _construct_nmant = _construct_nexp = Integer + + + @property + def max_exponent(self): + """ The largest positive number n, such that 2**(n - 1) is a representable finite value. """ + # cf. C++'s ``std::numeric_limits::max_exponent`` + return two**(self.nexp - 1) + + @property + def min_exponent(self): + """ The lowest negative number n, such that 2**(n - 1) is a valid normalized number. """ + # cf. C++'s ``std::numeric_limits::min_exponent`` + return 3 - self.max_exponent + + @property + def max(self): + """ Maximum value representable. """ + return (1 - two**-(self.nmant+1))*two**self.max_exponent + + @property + def tiny(self): + """ The minimum positive normalized value. """ + # See C macros: FLT_MIN, DBL_MIN, LDBL_MIN + # or C++'s ``std::numeric_limits::min`` + # or numpy.finfo(dtype).tiny + return two**(self.min_exponent - 1) + + + @property + def eps(self): + """ Difference between 1.0 and the next representable value. """ + return two**(-self.nmant) + + @property + def dig(self): + """ Number of decimal digits that are guaranteed to be preserved in text. + + When converting text -> float -> text, you are guaranteed that at least ``dig`` + number of digits are preserved with respect to rounding or overflow. + """ + from sympy.functions import floor, log + return floor(self.nmant * log(2)/log(10)) + + @property + def decimal_dig(self): + """ Number of digits needed to store & load without loss. + + Explanation + =========== + + Number of decimal digits needed to guarantee that two consecutive conversions + (float -> text -> float) to be idempotent. This is useful when one do not want + to loose precision due to rounding errors when storing a floating point value + as text. + """ + from sympy.functions import ceiling, log + return ceiling((self.nmant + 1) * log(2)/log(10) + 1) + + def cast_nocheck(self, value): + """ Casts without checking if out of bounds or subnormal. """ + if value == oo: # float(oo) or oo + return float(oo) + elif value == -oo: # float(-oo) or -oo + return float(-oo) + return Float(str(sympify(value).evalf(self.decimal_dig)), self.decimal_dig) + + def _check(self, value): + if value < -self.max: + raise ValueError("Value is too small: %d < %d" % (value, -self.max)) + if value > self.max: + raise ValueError("Value is too big: %d > %d" % (value, self.max)) + if abs(value) < self.tiny: + raise ValueError("Smallest (absolute) value for data type bigger than new value.") + +class ComplexBaseType(FloatBaseType): + + __slots__ = () + + def cast_nocheck(self, value): + """ Casts without checking if out of bounds or subnormal. """ + from sympy.functions import re, im + return ( + super().cast_nocheck(re(value)) + + super().cast_nocheck(im(value))*1j + ) + + def _check(self, value): + from sympy.functions import re, im + super()._check(re(value)) + super()._check(im(value)) + + +class ComplexType(ComplexBaseType, FloatType): + """ Represents a complex floating point number. """ + __slots__ = () + + +# NumPy types: +intc = IntBaseType('intc') +intp = IntBaseType('intp') +int8 = SignedIntType('int8', 8) +int16 = SignedIntType('int16', 16) +int32 = SignedIntType('int32', 32) +int64 = SignedIntType('int64', 64) +uint8 = UnsignedIntType('uint8', 8) +uint16 = UnsignedIntType('uint16', 16) +uint32 = UnsignedIntType('uint32', 32) +uint64 = UnsignedIntType('uint64', 64) +float16 = FloatType('float16', 16, nexp=5, nmant=10) # IEEE 754 binary16, Half precision +float32 = FloatType('float32', 32, nexp=8, nmant=23) # IEEE 754 binary32, Single precision +float64 = FloatType('float64', 64, nexp=11, nmant=52) # IEEE 754 binary64, Double precision +float80 = FloatType('float80', 80, nexp=15, nmant=63) # x86 extended precision (1 integer part bit), "long double" +float128 = FloatType('float128', 128, nexp=15, nmant=112) # IEEE 754 binary128, Quadruple precision +float256 = FloatType('float256', 256, nexp=19, nmant=236) # IEEE 754 binary256, Octuple precision + +complex64 = ComplexType('complex64', nbits=64, **float32.kwargs(exclude=('name', 'nbits'))) +complex128 = ComplexType('complex128', nbits=128, **float64.kwargs(exclude=('name', 'nbits'))) + +# Generic types (precision may be chosen by code printers): +untyped = Type('untyped') +real = FloatBaseType('real') +integer = IntBaseType('integer') +complex_ = ComplexBaseType('complex') +bool_ = Type('bool') + + +class Attribute(Token): + """ Attribute (possibly parametrized) + + For use with :class:`sympy.codegen.ast.Node` (which takes instances of + ``Attribute`` as ``attrs``). + + Parameters + ========== + + name : str + parameters : Tuple + + Examples + ======== + + >>> from sympy.codegen.ast import Attribute + >>> volatile = Attribute('volatile') + >>> volatile + volatile + >>> print(repr(volatile)) + Attribute(String('volatile')) + >>> a = Attribute('foo', [1, 2, 3]) + >>> a + foo(1, 2, 3) + >>> a.parameters == (1, 2, 3) + True + """ + __slots__ = _fields = ('name', 'parameters') + defaults = {'parameters': Tuple()} + + _construct_name = String + _construct_parameters = staticmethod(_mk_Tuple) + + def _sympystr(self, printer, *args, **kwargs): + result = str(self.name) + if self.parameters: + result += '(%s)' % ', '.join((printer._print( + arg, *args, **kwargs) for arg in self.parameters)) + return result + +value_const = Attribute('value_const') +pointer_const = Attribute('pointer_const') + + +class Variable(Node): + """ Represents a variable. + + Parameters + ========== + + symbol : Symbol + type : Type (optional) + Type of the variable. + attrs : iterable of Attribute instances + Will be stored as a Tuple. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.codegen.ast import Variable, float32, integer + >>> x = Symbol('x') + >>> v = Variable(x, type=float32) + >>> v.attrs + () + >>> v == Variable('x') + False + >>> v == Variable('x', type=float32) + True + >>> v + Variable(x, type=float32) + + One may also construct a ``Variable`` instance with the type deduced from + assumptions about the symbol using the ``deduced`` classmethod: + + >>> i = Symbol('i', integer=True) + >>> v = Variable.deduced(i) + >>> v.type == integer + True + >>> v == Variable('i') + False + >>> from sympy.codegen.ast import value_const + >>> value_const in v.attrs + False + >>> w = Variable('w', attrs=[value_const]) + >>> w + Variable(w, attrs=(value_const,)) + >>> value_const in w.attrs + True + >>> w.as_Declaration(value=42) + Declaration(Variable(w, value=42, attrs=(value_const,))) + + """ + + __slots__ = ('symbol', 'type', 'value') + _fields = __slots__ + Node._fields + + defaults = Node.defaults.copy() + defaults.update({'type': untyped, 'value': none}) + + _construct_symbol = staticmethod(sympify) + _construct_value = staticmethod(sympify) + + @classmethod + def deduced(cls, symbol, value=None, attrs=Tuple(), cast_check=True): + """ Alt. constructor with type deduction from ``Type.from_expr``. + + Deduces type primarily from ``symbol``, secondarily from ``value``. + + Parameters + ========== + + symbol : Symbol + value : expr + (optional) value of the variable. + attrs : iterable of Attribute instances + cast_check : bool + Whether to apply ``Type.cast_check`` on ``value``. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.codegen.ast import Variable, complex_ + >>> n = Symbol('n', integer=True) + >>> str(Variable.deduced(n).type) + 'integer' + >>> x = Symbol('x', real=True) + >>> v = Variable.deduced(x) + >>> v.type + real + >>> z = Symbol('z', complex=True) + >>> Variable.deduced(z).type == complex_ + True + + """ + if isinstance(symbol, Variable): + return symbol + + try: + type_ = Type.from_expr(symbol) + except ValueError: + type_ = Type.from_expr(value) + + if value is not None and cast_check: + value = type_.cast_check(value) + return cls(symbol, type=type_, value=value, attrs=attrs) + + def as_Declaration(self, **kwargs): + """ Convenience method for creating a Declaration instance. + + Explanation + =========== + + If the variable of the Declaration need to wrap a modified + variable keyword arguments may be passed (overriding e.g. + the ``value`` of the Variable instance). + + Examples + ======== + + >>> from sympy.codegen.ast import Variable, NoneToken + >>> x = Variable('x') + >>> decl1 = x.as_Declaration() + >>> # value is special NoneToken() which must be tested with == operator + >>> decl1.variable.value is None # won't work + False + >>> decl1.variable.value == None # not PEP-8 compliant + True + >>> decl1.variable.value == NoneToken() # OK + True + >>> decl2 = x.as_Declaration(value=42.0) + >>> decl2.variable.value == 42.0 + True + + """ + kw = self.kwargs() + kw.update(kwargs) + return Declaration(self.func(**kw)) + + def _relation(self, rhs, op): + try: + rhs = _sympify(rhs) + except SympifyError: + raise TypeError("Invalid comparison %s < %s" % (self, rhs)) + return op(self, rhs, evaluate=False) + + __lt__ = lambda self, other: self._relation(other, Lt) + __le__ = lambda self, other: self._relation(other, Le) + __ge__ = lambda self, other: self._relation(other, Ge) + __gt__ = lambda self, other: self._relation(other, Gt) + +class Pointer(Variable): + """ Represents a pointer. See ``Variable``. + + Examples + ======== + + Can create instances of ``Element``: + + >>> from sympy import Symbol + >>> from sympy.codegen.ast import Pointer + >>> i = Symbol('i', integer=True) + >>> p = Pointer('x') + >>> p[i+1] + Element(x, indices=(i + 1,)) + + """ + __slots__ = () + + def __getitem__(self, key): + try: + return Element(self.symbol, key) + except TypeError: + return Element(self.symbol, (key,)) + + +class Element(Token): + """ Element in (a possibly N-dimensional) array. + + Examples + ======== + + >>> from sympy.codegen.ast import Element + >>> elem = Element('x', 'ijk') + >>> elem.symbol.name == 'x' + True + >>> elem.indices + (i, j, k) + >>> from sympy import ccode + >>> ccode(elem) + 'x[i][j][k]' + >>> ccode(Element('x', 'ijk', strides='lmn', offset='o')) + 'x[i*l + j*m + k*n + o]' + + """ + __slots__ = _fields = ('symbol', 'indices', 'strides', 'offset') + defaults = {'strides': none, 'offset': none} + _construct_symbol = staticmethod(sympify) + _construct_indices = staticmethod(lambda arg: Tuple(*arg)) + _construct_strides = staticmethod(lambda arg: Tuple(*arg)) + _construct_offset = staticmethod(sympify) + + +class Declaration(Token): + """ Represents a variable declaration + + Parameters + ========== + + variable : Variable + + Examples + ======== + + >>> from sympy.codegen.ast import Declaration, NoneToken, untyped + >>> z = Declaration('z') + >>> z.variable.type == untyped + True + >>> # value is special NoneToken() which must be tested with == operator + >>> z.variable.value is None # won't work + False + >>> z.variable.value == None # not PEP-8 compliant + True + >>> z.variable.value == NoneToken() # OK + True + """ + __slots__ = _fields = ('variable',) + _construct_variable = Variable + + +class While(Token): + """ Represents a 'for-loop' in the code. + + Expressions are of the form: + "while condition: + body..." + + Parameters + ========== + + condition : expression convertible to Boolean + body : CodeBlock or iterable + When passed an iterable it is used to instantiate a CodeBlock. + + Examples + ======== + + >>> from sympy import symbols, Gt, Abs + >>> from sympy.codegen import aug_assign, Assignment, While + >>> x, dx = symbols('x dx') + >>> expr = 1 - x**2 + >>> whl = While(Gt(Abs(dx), 1e-9), [ + ... Assignment(dx, -expr/expr.diff(x)), + ... aug_assign(x, '+', dx) + ... ]) + + """ + __slots__ = _fields = ('condition', 'body') + _construct_condition = staticmethod(lambda cond: _sympify(cond)) + + @classmethod + def _construct_body(cls, itr): + if isinstance(itr, CodeBlock): + return itr + else: + return CodeBlock(*itr) + + +class Scope(Token): + """ Represents a scope in the code. + + Parameters + ========== + + body : CodeBlock or iterable + When passed an iterable it is used to instantiate a CodeBlock. + + """ + __slots__ = _fields = ('body',) + + @classmethod + def _construct_body(cls, itr): + if isinstance(itr, CodeBlock): + return itr + else: + return CodeBlock(*itr) + + +class Stream(Token): + """ Represents a stream. + + There are two predefined Stream instances ``stdout`` & ``stderr``. + + Parameters + ========== + + name : str + + Examples + ======== + + >>> from sympy import pycode, Symbol + >>> from sympy.codegen.ast import Print, stderr, QuotedString + >>> print(pycode(Print(['x'], file=stderr))) + print(x, file=sys.stderr) + >>> x = Symbol('x') + >>> print(pycode(Print([QuotedString('x')], file=stderr))) # print literally "x" + print("x", file=sys.stderr) + + """ + __slots__ = _fields = ('name',) + _construct_name = String + +stdout = Stream('stdout') +stderr = Stream('stderr') + + +class Print(Token): + r""" Represents print command in the code. + + Parameters + ========== + + formatstring : str + *args : Basic instances (or convertible to such through sympify) + + Examples + ======== + + >>> from sympy.codegen.ast import Print + >>> from sympy import pycode + >>> print(pycode(Print('x y'.split(), "coordinate: %12.5g %12.5g\\n"))) + print("coordinate: %12.5g %12.5g\n" % (x, y), end="") + + """ + + __slots__ = _fields = ('print_args', 'format_string', 'file') + defaults = {'format_string': none, 'file': none} + + _construct_print_args = staticmethod(_mk_Tuple) + _construct_format_string = QuotedString + _construct_file = Stream + + +class FunctionPrototype(Node): + """ Represents a function prototype + + Allows the user to generate forward declaration in e.g. C/C++. + + Parameters + ========== + + return_type : Type + name : str + parameters: iterable of Variable instances + attrs : iterable of Attribute instances + + Examples + ======== + + >>> from sympy import ccode, symbols + >>> from sympy.codegen.ast import real, FunctionPrototype + >>> x, y = symbols('x y', real=True) + >>> fp = FunctionPrototype(real, 'foo', [x, y]) + >>> ccode(fp) + 'double foo(double x, double y)' + + """ + + __slots__ = ('return_type', 'name', 'parameters') + _fields: tuple[str, ...] = __slots__ + Node._fields + + _construct_return_type = Type + _construct_name = String + + @staticmethod + def _construct_parameters(args): + def _var(arg): + if isinstance(arg, Declaration): + return arg.variable + elif isinstance(arg, Variable): + return arg + else: + return Variable.deduced(arg) + return Tuple(*map(_var, args)) + + @classmethod + def from_FunctionDefinition(cls, func_def): + if not isinstance(func_def, FunctionDefinition): + raise TypeError("func_def is not an instance of FunctionDefinition") + return cls(**func_def.kwargs(exclude=('body',))) + + +class FunctionDefinition(FunctionPrototype): + """ Represents a function definition in the code. + + Parameters + ========== + + return_type : Type + name : str + parameters: iterable of Variable instances + body : CodeBlock or iterable + attrs : iterable of Attribute instances + + Examples + ======== + + >>> from sympy import ccode, symbols + >>> from sympy.codegen.ast import real, FunctionPrototype + >>> x, y = symbols('x y', real=True) + >>> fp = FunctionPrototype(real, 'foo', [x, y]) + >>> ccode(fp) + 'double foo(double x, double y)' + >>> from sympy.codegen.ast import FunctionDefinition, Return + >>> body = [Return(x*y)] + >>> fd = FunctionDefinition.from_FunctionPrototype(fp, body) + >>> print(ccode(fd)) + double foo(double x, double y){ + return x*y; + } + """ + + __slots__ = ('body', ) + _fields = FunctionPrototype._fields[:-1] + __slots__ + Node._fields + + @classmethod + def _construct_body(cls, itr): + if isinstance(itr, CodeBlock): + return itr + else: + return CodeBlock(*itr) + + @classmethod + def from_FunctionPrototype(cls, func_proto, body): + if not isinstance(func_proto, FunctionPrototype): + raise TypeError("func_proto is not an instance of FunctionPrototype") + return cls(body=body, **func_proto.kwargs()) + + +class Return(Token): + """ Represents a return command in the code. + + Parameters + ========== + + return : Basic + + Examples + ======== + + >>> from sympy.codegen.ast import Return + >>> from sympy.printing.pycode import pycode + >>> from sympy import Symbol + >>> x = Symbol('x') + >>> print(pycode(Return(x))) + return x + + """ + __slots__ = _fields = ('return',) + _construct_return=staticmethod(_sympify) + + +class FunctionCall(Token, Expr): + """ Represents a call to a function in the code. + + Parameters + ========== + + name : str + function_args : Tuple + + Examples + ======== + + >>> from sympy.codegen.ast import FunctionCall + >>> from sympy import pycode + >>> fcall = FunctionCall('foo', 'bar baz'.split()) + >>> print(pycode(fcall)) + foo(bar, baz) + + """ + __slots__ = _fields = ('name', 'function_args') + + _construct_name = String + _construct_function_args = staticmethod(lambda args: Tuple(*args)) + + +class Raise(Token): + """ Prints as 'raise ...' in Python, 'throw ...' in C++""" + __slots__ = _fields = ('exception',) + + +class RuntimeError_(Token): + """ Represents 'std::runtime_error' in C++ and 'RuntimeError' in Python. + + Note that the latter is uncommon, and you might want to use e.g. ValueError. + """ + __slots__ = _fields = ('message',) + _construct_message = String diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/cfunctions.py b/.venv/lib/python3.13/site-packages/sympy/codegen/cfunctions.py new file mode 100644 index 0000000000000000000000000000000000000000..7b79291f128aef1cb83b327782840508e59a9cc8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/cfunctions.py @@ -0,0 +1,558 @@ +""" +This module contains SymPy functions mathcin corresponding to special math functions in the +C standard library (since C99, also available in C++11). + +The functions defined in this module allows the user to express functions such as ``expm1`` +as a SymPy function for symbolic manipulation. + +""" +from sympy.core.function import ArgumentIndexError, Function +from sympy.core.numbers import Rational +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.logic.boolalg import BooleanFunction, true, false + +def _expm1(x): + return exp(x) - S.One + + +class expm1(Function): + """ + Represents the exponential function minus one. + + Explanation + =========== + + The benefit of using ``expm1(x)`` over ``exp(x) - 1`` + is that the latter is prone to cancellation under finite precision + arithmetic when x is close to zero. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy.codegen.cfunctions import expm1 + >>> '%.0e' % expm1(1e-99).evalf() + '1e-99' + >>> from math import exp + >>> exp(1e-99) - 1 + 0.0 + >>> expm1(x).diff(x) + exp(x) + + See Also + ======== + + log1p + """ + nargs = 1 + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function. + """ + if argindex == 1: + return exp(*self.args) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_expand_func(self, **hints): + return _expm1(*self.args) + + def _eval_rewrite_as_exp(self, arg, **kwargs): + return exp(arg) - S.One + + _eval_rewrite_as_tractable = _eval_rewrite_as_exp + + @classmethod + def eval(cls, arg): + exp_arg = exp.eval(arg) + if exp_arg is not None: + return exp_arg - S.One + + def _eval_is_real(self): + return self.args[0].is_real + + def _eval_is_finite(self): + return self.args[0].is_finite + + +def _log1p(x): + return log(x + S.One) + + +class log1p(Function): + """ + Represents the natural logarithm of a number plus one. + + Explanation + =========== + + The benefit of using ``log1p(x)`` over ``log(x + 1)`` + is that the latter is prone to cancellation under finite precision + arithmetic when x is close to zero. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy.codegen.cfunctions import log1p + >>> from sympy import expand_log + >>> '%.0e' % expand_log(log1p(1e-99)).evalf() + '1e-99' + >>> from math import log + >>> log(1 + 1e-99) + 0.0 + >>> log1p(x).diff(x) + 1/(x + 1) + + See Also + ======== + + expm1 + """ + nargs = 1 + + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function. + """ + if argindex == 1: + return S.One/(self.args[0] + S.One) + else: + raise ArgumentIndexError(self, argindex) + + + def _eval_expand_func(self, **hints): + return _log1p(*self.args) + + def _eval_rewrite_as_log(self, arg, **kwargs): + return _log1p(arg) + + _eval_rewrite_as_tractable = _eval_rewrite_as_log + + @classmethod + def eval(cls, arg): + if arg.is_Rational: + return log(arg + S.One) + elif not arg.is_Float: # not safe to add 1 to Float + return log.eval(arg + S.One) + elif arg.is_number: + return log(Rational(arg) + S.One) + + def _eval_is_real(self): + return (self.args[0] + S.One).is_nonnegative + + def _eval_is_finite(self): + if (self.args[0] + S.One).is_zero: + return False + return self.args[0].is_finite + + def _eval_is_positive(self): + return self.args[0].is_positive + + def _eval_is_zero(self): + return self.args[0].is_zero + + def _eval_is_nonnegative(self): + return self.args[0].is_nonnegative + +_Two = S(2) + +def _exp2(x): + return Pow(_Two, x) + +class exp2(Function): + """ + Represents the exponential function with base two. + + Explanation + =========== + + The benefit of using ``exp2(x)`` over ``2**x`` + is that the latter is not as efficient under finite precision + arithmetic. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy.codegen.cfunctions import exp2 + >>> exp2(2).evalf() == 4.0 + True + >>> exp2(x).diff(x) + log(2)*exp2(x) + + See Also + ======== + + log2 + """ + nargs = 1 + + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function. + """ + if argindex == 1: + return self*log(_Two) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_Pow(self, arg, **kwargs): + return _exp2(arg) + + _eval_rewrite_as_tractable = _eval_rewrite_as_Pow + + def _eval_expand_func(self, **hints): + return _exp2(*self.args) + + @classmethod + def eval(cls, arg): + if arg.is_number: + return _exp2(arg) + + +def _log2(x): + return log(x)/log(_Two) + + +class log2(Function): + """ + Represents the logarithm function with base two. + + Explanation + =========== + + The benefit of using ``log2(x)`` over ``log(x)/log(2)`` + is that the latter is not as efficient under finite precision + arithmetic. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy.codegen.cfunctions import log2 + >>> log2(4).evalf() == 2.0 + True + >>> log2(x).diff(x) + 1/(x*log(2)) + + See Also + ======== + + exp2 + log10 + """ + nargs = 1 + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function. + """ + if argindex == 1: + return S.One/(log(_Two)*self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + + @classmethod + def eval(cls, arg): + if arg.is_number: + result = log.eval(arg, base=_Two) + if result.is_Atom: + return result + elif arg.is_Pow and arg.base == _Two: + return arg.exp + + def _eval_evalf(self, *args, **kwargs): + return self.rewrite(log).evalf(*args, **kwargs) + + def _eval_expand_func(self, **hints): + return _log2(*self.args) + + def _eval_rewrite_as_log(self, arg, **kwargs): + return _log2(arg) + + _eval_rewrite_as_tractable = _eval_rewrite_as_log + + +def _fma(x, y, z): + return x*y + z + + +class fma(Function): + """ + Represents "fused multiply add". + + Explanation + =========== + + The benefit of using ``fma(x, y, z)`` over ``x*y + z`` + is that, under finite precision arithmetic, the former is + supported by special instructions on some CPUs. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.codegen.cfunctions import fma + >>> fma(x, y, z).diff(x) + y + + """ + nargs = 3 + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function. + """ + if argindex in (1, 2): + return self.args[2 - argindex] + elif argindex == 3: + return S.One + else: + raise ArgumentIndexError(self, argindex) + + + def _eval_expand_func(self, **hints): + return _fma(*self.args) + + def _eval_rewrite_as_tractable(self, arg, limitvar=None, **kwargs): + return _fma(arg) + + +_Ten = S(10) + + +def _log10(x): + return log(x)/log(_Ten) + + +class log10(Function): + """ + Represents the logarithm function with base ten. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy.codegen.cfunctions import log10 + >>> log10(100).evalf() == 2.0 + True + >>> log10(x).diff(x) + 1/(x*log(10)) + + See Also + ======== + + log2 + """ + nargs = 1 + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function. + """ + if argindex == 1: + return S.One/(log(_Ten)*self.args[0]) + else: + raise ArgumentIndexError(self, argindex) + + + @classmethod + def eval(cls, arg): + if arg.is_number: + result = log.eval(arg, base=_Ten) + if result.is_Atom: + return result + elif arg.is_Pow and arg.base == _Ten: + return arg.exp + + def _eval_expand_func(self, **hints): + return _log10(*self.args) + + def _eval_rewrite_as_log(self, arg, **kwargs): + return _log10(arg) + + _eval_rewrite_as_tractable = _eval_rewrite_as_log + + +def _Sqrt(x): + return Pow(x, S.Half) + + +class Sqrt(Function): # 'sqrt' already defined in sympy.functions.elementary.miscellaneous + """ + Represents the square root function. + + Explanation + =========== + + The reason why one would use ``Sqrt(x)`` over ``sqrt(x)`` + is that the latter is internally represented as ``Pow(x, S.Half)`` which + may not be what one wants when doing code-generation. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy.codegen.cfunctions import Sqrt + >>> Sqrt(x) + Sqrt(x) + >>> Sqrt(x).diff(x) + 1/(2*sqrt(x)) + + See Also + ======== + + Cbrt + """ + nargs = 1 + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function. + """ + if argindex == 1: + return Pow(self.args[0], Rational(-1, 2))/_Two + else: + raise ArgumentIndexError(self, argindex) + + def _eval_expand_func(self, **hints): + return _Sqrt(*self.args) + + def _eval_rewrite_as_Pow(self, arg, **kwargs): + return _Sqrt(arg) + + _eval_rewrite_as_tractable = _eval_rewrite_as_Pow + + +def _Cbrt(x): + return Pow(x, Rational(1, 3)) + + +class Cbrt(Function): # 'cbrt' already defined in sympy.functions.elementary.miscellaneous + """ + Represents the cube root function. + + Explanation + =========== + + The reason why one would use ``Cbrt(x)`` over ``cbrt(x)`` + is that the latter is internally represented as ``Pow(x, Rational(1, 3))`` which + may not be what one wants when doing code-generation. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy.codegen.cfunctions import Cbrt + >>> Cbrt(x) + Cbrt(x) + >>> Cbrt(x).diff(x) + 1/(3*x**(2/3)) + + See Also + ======== + + Sqrt + """ + nargs = 1 + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function. + """ + if argindex == 1: + return Pow(self.args[0], Rational(-_Two/3))/3 + else: + raise ArgumentIndexError(self, argindex) + + + def _eval_expand_func(self, **hints): + return _Cbrt(*self.args) + + def _eval_rewrite_as_Pow(self, arg, **kwargs): + return _Cbrt(arg) + + _eval_rewrite_as_tractable = _eval_rewrite_as_Pow + + +def _hypot(x, y): + return sqrt(Pow(x, 2) + Pow(y, 2)) + + +class hypot(Function): + """ + Represents the hypotenuse function. + + Explanation + =========== + + The hypotenuse function is provided by e.g. the math library + in the C99 standard, hence one may want to represent the function + symbolically when doing code-generation. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.codegen.cfunctions import hypot + >>> hypot(3, 4).evalf() == 5.0 + True + >>> hypot(x, y) + hypot(x, y) + >>> hypot(x, y).diff(x) + x/hypot(x, y) + + """ + nargs = 2 + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function. + """ + if argindex in (1, 2): + return 2*self.args[argindex-1]/(_Two*self.func(*self.args)) + else: + raise ArgumentIndexError(self, argindex) + + + def _eval_expand_func(self, **hints): + return _hypot(*self.args) + + def _eval_rewrite_as_Pow(self, arg, **kwargs): + return _hypot(arg) + + _eval_rewrite_as_tractable = _eval_rewrite_as_Pow + + +class isnan(BooleanFunction): + nargs = 1 + + @classmethod + def eval(cls, arg): + if arg is S.NaN: + return true + elif arg.is_number: + return false + else: + return None + + +class isinf(BooleanFunction): + nargs = 1 + + @classmethod + def eval(cls, arg): + if arg.is_infinite: + return true + elif arg.is_finite: + return false + else: + return None diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/cnodes.py b/.venv/lib/python3.13/site-packages/sympy/codegen/cnodes.py new file mode 100644 index 0000000000000000000000000000000000000000..dd2a324ee49cabcd42e99ca5d8e5379cc9262c4e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/cnodes.py @@ -0,0 +1,156 @@ +""" +AST nodes specific to the C family of languages +""" + +from sympy.codegen.ast import ( + Attribute, Declaration, Node, String, Token, Type, none, + FunctionCall, CodeBlock + ) +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.sympify import sympify + +void = Type('void') + +restrict = Attribute('restrict') # guarantees no pointer aliasing +volatile = Attribute('volatile') +static = Attribute('static') + + +def alignof(arg): + """ Generate of FunctionCall instance for calling 'alignof' """ + return FunctionCall('alignof', [String(arg) if isinstance(arg, str) else arg]) + + +def sizeof(arg): + """ Generate of FunctionCall instance for calling 'sizeof' + + Examples + ======== + + >>> from sympy.codegen.ast import real + >>> from sympy.codegen.cnodes import sizeof + >>> from sympy import ccode + >>> ccode(sizeof(real)) + 'sizeof(double)' + """ + return FunctionCall('sizeof', [String(arg) if isinstance(arg, str) else arg]) + + +class CommaOperator(Basic): + """ Represents the comma operator in C """ + def __new__(cls, *args): + return Basic.__new__(cls, *[sympify(arg) for arg in args]) + + +class Label(Node): + """ Label for use with e.g. goto statement. + + Examples + ======== + + >>> from sympy import ccode, Symbol + >>> from sympy.codegen.cnodes import Label, PreIncrement + >>> print(ccode(Label('foo'))) + foo: + >>> print(ccode(Label('bar', [PreIncrement(Symbol('a'))]))) + bar: + ++(a); + + """ + __slots__ = _fields = ('name', 'body') + defaults = {'body': none} + _construct_name = String + + @classmethod + def _construct_body(cls, itr): + if isinstance(itr, CodeBlock): + return itr + else: + return CodeBlock(*itr) + + +class goto(Token): + """ Represents goto in C """ + __slots__ = _fields = ('label',) + _construct_label = Label + + +class PreDecrement(Basic): + """ Represents the pre-decrement operator + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy.codegen.cnodes import PreDecrement + >>> from sympy import ccode + >>> ccode(PreDecrement(x)) + '--(x)' + + """ + nargs = 1 + + +class PostDecrement(Basic): + """ Represents the post-decrement operator + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy.codegen.cnodes import PostDecrement + >>> from sympy import ccode + >>> ccode(PostDecrement(x)) + '(x)--' + + """ + nargs = 1 + + +class PreIncrement(Basic): + """ Represents the pre-increment operator + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy.codegen.cnodes import PreIncrement + >>> from sympy import ccode + >>> ccode(PreIncrement(x)) + '++(x)' + + """ + nargs = 1 + + +class PostIncrement(Basic): + """ Represents the post-increment operator + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy.codegen.cnodes import PostIncrement + >>> from sympy import ccode + >>> ccode(PostIncrement(x)) + '(x)++' + + """ + nargs = 1 + + +class struct(Node): + """ Represents a struct in C """ + __slots__ = _fields = ('name', 'declarations') + defaults = {'name': none} + _construct_name = String + + @classmethod + def _construct_declarations(cls, args): + return Tuple(*[Declaration(arg) for arg in args]) + + +class union(struct): + """ Represents a union in C """ + __slots__ = () diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/cutils.py b/.venv/lib/python3.13/site-packages/sympy/codegen/cutils.py new file mode 100644 index 0000000000000000000000000000000000000000..2182ac1f3455da490a0bb57f8d6731fe8a29a232 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/cutils.py @@ -0,0 +1,8 @@ +from sympy.printing.c import C99CodePrinter + +def render_as_source_file(content, Printer=C99CodePrinter, settings=None): + """ Renders a C source file (with required #include statements) """ + printer = Printer(settings or {}) + code_str = printer.doprint(content) + includes = '\n'.join(['#include <%s>' % h for h in printer.headers]) + return includes + '\n\n' + code_str diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/cxxnodes.py b/.venv/lib/python3.13/site-packages/sympy/codegen/cxxnodes.py new file mode 100644 index 0000000000000000000000000000000000000000..7f7aafd01ab2de99ad0f668275889863fc73f5aa --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/cxxnodes.py @@ -0,0 +1,14 @@ +""" +AST nodes specific to C++. +""" + +from sympy.codegen.ast import Attribute, String, Token, Type, none + +class using(Token): + """ Represents a 'using' statement in C++ """ + __slots__ = _fields = ('type', 'alias') + defaults = {'alias': none} + _construct_type = Type + _construct_alias = String + +constexpr = Attribute('constexpr') diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/fnodes.py b/.venv/lib/python3.13/site-packages/sympy/codegen/fnodes.py new file mode 100644 index 0000000000000000000000000000000000000000..8b972fcfe4d4de4cfe74d75705f42c5e112a9b43 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/fnodes.py @@ -0,0 +1,658 @@ +""" +AST nodes specific to Fortran. + +The functions defined in this module allows the user to express functions such as ``dsign`` +as a SymPy function for symbolic manipulation. +""" + +from __future__ import annotations +from sympy.codegen.ast import ( + Attribute, CodeBlock, FunctionCall, Node, none, String, + Token, _mk_Tuple, Variable +) +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import Function +from sympy.core.numbers import Float, Integer +from sympy.core.symbol import Str +from sympy.core.sympify import sympify +from sympy.logic import true, false +from sympy.utilities.iterables import iterable + + + +pure = Attribute('pure') +elemental = Attribute('elemental') # (all elemental procedures are also pure) + +intent_in = Attribute('intent_in') +intent_out = Attribute('intent_out') +intent_inout = Attribute('intent_inout') + +allocatable = Attribute('allocatable') + +class Program(Token): + """ Represents a 'program' block in Fortran. + + Examples + ======== + + >>> from sympy.codegen.ast import Print + >>> from sympy.codegen.fnodes import Program + >>> prog = Program('myprogram', [Print([42])]) + >>> from sympy import fcode + >>> print(fcode(prog, source_format='free')) + program myprogram + print *, 42 + end program + + """ + __slots__ = _fields = ('name', 'body') + _construct_name = String + _construct_body = staticmethod(lambda body: CodeBlock(*body)) + + +class use_rename(Token): + """ Represents a renaming in a use statement in Fortran. + + Examples + ======== + + >>> from sympy.codegen.fnodes import use_rename, use + >>> from sympy import fcode + >>> ren = use_rename("thingy", "convolution2d") + >>> print(fcode(ren, source_format='free')) + thingy => convolution2d + >>> full = use('signallib', only=['snr', ren]) + >>> print(fcode(full, source_format='free')) + use signallib, only: snr, thingy => convolution2d + + """ + __slots__ = _fields = ('local', 'original') + _construct_local = String + _construct_original = String + +def _name(arg): + if hasattr(arg, 'name'): + return arg.name + else: + return String(arg) + +class use(Token): + """ Represents a use statement in Fortran. + + Examples + ======== + + >>> from sympy.codegen.fnodes import use + >>> from sympy import fcode + >>> fcode(use('signallib'), source_format='free') + 'use signallib' + >>> fcode(use('signallib', [('metric', 'snr')]), source_format='free') + 'use signallib, metric => snr' + >>> fcode(use('signallib', only=['snr', 'convolution2d']), source_format='free') + 'use signallib, only: snr, convolution2d' + + """ + __slots__ = _fields = ('namespace', 'rename', 'only') + defaults = {'rename': none, 'only': none} + _construct_namespace = staticmethod(_name) + _construct_rename = staticmethod(lambda args: Tuple(*[arg if isinstance(arg, use_rename) else use_rename(*arg) for arg in args])) + _construct_only = staticmethod(lambda args: Tuple(*[arg if isinstance(arg, use_rename) else _name(arg) for arg in args])) + + +class Module(Token): + """ Represents a module in Fortran. + + Examples + ======== + + >>> from sympy.codegen.fnodes import Module + >>> from sympy import fcode + >>> print(fcode(Module('signallib', ['implicit none'], []), source_format='free')) + module signallib + implicit none + + contains + + + end module + + """ + __slots__ = _fields = ('name', 'declarations', 'definitions') + defaults = {'declarations': Tuple()} + _construct_name = String + + @classmethod + def _construct_declarations(cls, args): + args = [Str(arg) if isinstance(arg, str) else arg for arg in args] + return CodeBlock(*args) + + _construct_definitions = staticmethod(lambda arg: CodeBlock(*arg)) + + +class Subroutine(Node): + """ Represents a subroutine in Fortran. + + Examples + ======== + + >>> from sympy import fcode, symbols + >>> from sympy.codegen.ast import Print + >>> from sympy.codegen.fnodes import Subroutine + >>> x, y = symbols('x y', real=True) + >>> sub = Subroutine('mysub', [x, y], [Print([x**2 + y**2, x*y])]) + >>> print(fcode(sub, source_format='free', standard=2003)) + subroutine mysub(x, y) + real*8 :: x + real*8 :: y + print *, x**2 + y**2, x*y + end subroutine + + """ + __slots__ = ('name', 'parameters', 'body') + _fields = __slots__ + Node._fields + _construct_name = String + _construct_parameters = staticmethod(lambda params: Tuple(*map(Variable.deduced, params))) + + @classmethod + def _construct_body(cls, itr): + if isinstance(itr, CodeBlock): + return itr + else: + return CodeBlock(*itr) + +class SubroutineCall(Token): + """ Represents a call to a subroutine in Fortran. + + Examples + ======== + + >>> from sympy.codegen.fnodes import SubroutineCall + >>> from sympy import fcode + >>> fcode(SubroutineCall('mysub', 'x y'.split())) + ' call mysub(x, y)' + + """ + __slots__ = _fields = ('name', 'subroutine_args') + _construct_name = staticmethod(_name) + _construct_subroutine_args = staticmethod(_mk_Tuple) + + +class Do(Token): + """ Represents a Do loop in in Fortran. + + Examples + ======== + + >>> from sympy import fcode, symbols + >>> from sympy.codegen.ast import aug_assign, Print + >>> from sympy.codegen.fnodes import Do + >>> i, n = symbols('i n', integer=True) + >>> r = symbols('r', real=True) + >>> body = [aug_assign(r, '+', 1/i), Print([i, r])] + >>> do1 = Do(body, i, 1, n) + >>> print(fcode(do1, source_format='free')) + do i = 1, n + r = r + 1d0/i + print *, i, r + end do + >>> do2 = Do(body, i, 1, n, 2) + >>> print(fcode(do2, source_format='free')) + do i = 1, n, 2 + r = r + 1d0/i + print *, i, r + end do + + """ + + __slots__ = _fields = ('body', 'counter', 'first', 'last', 'step', 'concurrent') + defaults = {'step': Integer(1), 'concurrent': false} + _construct_body = staticmethod(lambda body: CodeBlock(*body)) + _construct_counter = staticmethod(sympify) + _construct_first = staticmethod(sympify) + _construct_last = staticmethod(sympify) + _construct_step = staticmethod(sympify) + _construct_concurrent = staticmethod(lambda arg: true if arg else false) + + +class ArrayConstructor(Token): + """ Represents an array constructor. + + Examples + ======== + + >>> from sympy import fcode + >>> from sympy.codegen.fnodes import ArrayConstructor + >>> ac = ArrayConstructor([1, 2, 3]) + >>> fcode(ac, standard=95, source_format='free') + '(/1, 2, 3/)' + >>> fcode(ac, standard=2003, source_format='free') + '[1, 2, 3]' + + """ + __slots__ = _fields = ('elements',) + _construct_elements = staticmethod(_mk_Tuple) + + +class ImpliedDoLoop(Token): + """ Represents an implied do loop in Fortran. + + Examples + ======== + + >>> from sympy import Symbol, fcode + >>> from sympy.codegen.fnodes import ImpliedDoLoop, ArrayConstructor + >>> i = Symbol('i', integer=True) + >>> idl = ImpliedDoLoop(i**3, i, -3, 3, 2) # -27, -1, 1, 27 + >>> ac = ArrayConstructor([-28, idl, 28]) # -28, -27, -1, 1, 27, 28 + >>> fcode(ac, standard=2003, source_format='free') + '[-28, (i**3, i = -3, 3, 2), 28]' + + """ + __slots__ = _fields = ('expr', 'counter', 'first', 'last', 'step') + defaults = {'step': Integer(1)} + _construct_expr = staticmethod(sympify) + _construct_counter = staticmethod(sympify) + _construct_first = staticmethod(sympify) + _construct_last = staticmethod(sympify) + _construct_step = staticmethod(sympify) + + +class Extent(Basic): + """ Represents a dimension extent. + + Examples + ======== + + >>> from sympy.codegen.fnodes import Extent + >>> e = Extent(-3, 3) # -3, -2, -1, 0, 1, 2, 3 + >>> from sympy import fcode + >>> fcode(e, source_format='free') + '-3:3' + >>> from sympy.codegen.ast import Variable, real + >>> from sympy.codegen.fnodes import dimension, intent_out + >>> dim = dimension(e, e) + >>> arr = Variable('x', real, attrs=[dim, intent_out]) + >>> fcode(arr.as_Declaration(), source_format='free', standard=2003) + 'real*8, dimension(-3:3, -3:3), intent(out) :: x' + + """ + def __new__(cls, *args): + if len(args) == 2: + low, high = args + return Basic.__new__(cls, sympify(low), sympify(high)) + elif len(args) == 0 or (len(args) == 1 and args[0] in (':', None)): + return Basic.__new__(cls) # assumed shape + else: + raise ValueError("Expected 0 or 2 args (or one argument == None or ':')") + + def _sympystr(self, printer): + if len(self.args) == 0: + return ':' + return ":".join(str(arg) for arg in self.args) + +assumed_extent = Extent() # or Extent(':'), Extent(None) + + +def dimension(*args): + """ Creates a 'dimension' Attribute with (up to 7) extents. + + Examples + ======== + + >>> from sympy import fcode + >>> from sympy.codegen.fnodes import dimension, intent_in + >>> dim = dimension('2', ':') # 2 rows, runtime determined number of columns + >>> from sympy.codegen.ast import Variable, integer + >>> arr = Variable('a', integer, attrs=[dim, intent_in]) + >>> fcode(arr.as_Declaration(), source_format='free', standard=2003) + 'integer*4, dimension(2, :), intent(in) :: a' + + """ + if len(args) > 7: + raise ValueError("Fortran only supports up to 7 dimensional arrays") + parameters = [] + for arg in args: + if isinstance(arg, Extent): + parameters.append(arg) + elif isinstance(arg, str): + if arg == ':': + parameters.append(Extent()) + else: + parameters.append(String(arg)) + elif iterable(arg): + parameters.append(Extent(*arg)) + else: + parameters.append(sympify(arg)) + if len(args) == 0: + raise ValueError("Need at least one dimension") + return Attribute('dimension', parameters) + + +assumed_size = dimension('*') + +def array(symbol, dim, intent=None, *, attrs=(), value=None, type=None): + """ Convenience function for creating a Variable instance for a Fortran array. + + Parameters + ========== + + symbol : symbol + dim : Attribute or iterable + If dim is an ``Attribute`` it need to have the name 'dimension'. If it is + not an ``Attribute``, then it is passed to :func:`dimension` as ``*dim`` + intent : str + One of: 'in', 'out', 'inout' or None + \\*\\*kwargs: + Keyword arguments for ``Variable`` ('type' & 'value') + + Examples + ======== + + >>> from sympy import fcode + >>> from sympy.codegen.ast import integer, real + >>> from sympy.codegen.fnodes import array + >>> arr = array('a', '*', 'in', type=integer) + >>> print(fcode(arr.as_Declaration(), source_format='free', standard=2003)) + integer*4, dimension(*), intent(in) :: a + >>> x = array('x', [3, ':', ':'], intent='out', type=real) + >>> print(fcode(x.as_Declaration(value=1), source_format='free', standard=2003)) + real*8, dimension(3, :, :), intent(out) :: x = 1 + + """ + if isinstance(dim, Attribute): + if str(dim.name) != 'dimension': + raise ValueError("Got an unexpected Attribute argument as dim: %s" % str(dim)) + else: + dim = dimension(*dim) + + attrs = list(attrs) + [dim] + if intent is not None: + if intent not in (intent_in, intent_out, intent_inout): + intent = {'in': intent_in, 'out': intent_out, 'inout': intent_inout}[intent] + attrs.append(intent) + if type is None: + return Variable.deduced(symbol, value=value, attrs=attrs) + else: + return Variable(symbol, type, value=value, attrs=attrs) + +def _printable(arg): + return String(arg) if isinstance(arg, str) else sympify(arg) + + +def allocated(array): + """ Creates an AST node for a function call to Fortran's "allocated(...)" + + Examples + ======== + + >>> from sympy import fcode + >>> from sympy.codegen.fnodes import allocated + >>> alloc = allocated('x') + >>> fcode(alloc, source_format='free') + 'allocated(x)' + + """ + return FunctionCall('allocated', [_printable(array)]) + + +def lbound(array, dim=None, kind=None): + """ Creates an AST node for a function call to Fortran's "lbound(...)" + + Parameters + ========== + + array : Symbol or String + dim : expr + kind : expr + + Examples + ======== + + >>> from sympy import fcode + >>> from sympy.codegen.fnodes import lbound + >>> lb = lbound('arr', dim=2) + >>> fcode(lb, source_format='free') + 'lbound(arr, 2)' + + """ + return FunctionCall( + 'lbound', + [_printable(array)] + + ([_printable(dim)] if dim else []) + + ([_printable(kind)] if kind else []) + ) + + +def ubound(array, dim=None, kind=None): + return FunctionCall( + 'ubound', + [_printable(array)] + + ([_printable(dim)] if dim else []) + + ([_printable(kind)] if kind else []) + ) + + +def shape(source, kind=None): + """ Creates an AST node for a function call to Fortran's "shape(...)" + + Parameters + ========== + + source : Symbol or String + kind : expr + + Examples + ======== + + >>> from sympy import fcode + >>> from sympy.codegen.fnodes import shape + >>> shp = shape('x') + >>> fcode(shp, source_format='free') + 'shape(x)' + + """ + return FunctionCall( + 'shape', + [_printable(source)] + + ([_printable(kind)] if kind else []) + ) + + +def size(array, dim=None, kind=None): + """ Creates an AST node for a function call to Fortran's "size(...)" + + Examples + ======== + + >>> from sympy import fcode, Symbol + >>> from sympy.codegen.ast import FunctionDefinition, real, Return + >>> from sympy.codegen.fnodes import array, sum_, size + >>> a = Symbol('a', real=True) + >>> body = [Return((sum_(a**2)/size(a))**.5)] + >>> arr = array(a, dim=[':'], intent='in') + >>> fd = FunctionDefinition(real, 'rms', [arr], body) + >>> print(fcode(fd, source_format='free', standard=2003)) + real*8 function rms(a) + real*8, dimension(:), intent(in) :: a + rms = sqrt(sum(a**2)*1d0/size(a)) + end function + + """ + return FunctionCall( + 'size', + [_printable(array)] + + ([_printable(dim)] if dim else []) + + ([_printable(kind)] if kind else []) + ) + + +def reshape(source, shape, pad=None, order=None): + """ Creates an AST node for a function call to Fortran's "reshape(...)" + + Parameters + ========== + + source : Symbol or String + shape : ArrayExpr + + """ + return FunctionCall( + 'reshape', + [_printable(source), _printable(shape)] + + ([_printable(pad)] if pad else []) + + ([_printable(order)] if pad else []) + ) + + +def bind_C(name=None): + """ Creates an Attribute ``bind_C`` with a name. + + Parameters + ========== + + name : str + + Examples + ======== + + >>> from sympy import fcode, Symbol + >>> from sympy.codegen.ast import FunctionDefinition, real, Return + >>> from sympy.codegen.fnodes import array, sum_, bind_C + >>> a = Symbol('a', real=True) + >>> s = Symbol('s', integer=True) + >>> arr = array(a, dim=[s], intent='in') + >>> body = [Return((sum_(a**2)/s)**.5)] + >>> fd = FunctionDefinition(real, 'rms', [arr, s], body, attrs=[bind_C('rms')]) + >>> print(fcode(fd, source_format='free', standard=2003)) + real*8 function rms(a, s) bind(C, name="rms") + real*8, dimension(s), intent(in) :: a + integer*4 :: s + rms = sqrt(sum(a**2)/s) + end function + + """ + return Attribute('bind_C', [String(name)] if name else []) + +class GoTo(Token): + """ Represents a goto statement in Fortran + + Examples + ======== + + >>> from sympy.codegen.fnodes import GoTo + >>> go = GoTo([10, 20, 30], 'i') + >>> from sympy import fcode + >>> fcode(go, source_format='free') + 'go to (10, 20, 30), i' + + """ + __slots__ = _fields = ('labels', 'expr') + defaults = {'expr': none} + _construct_labels = staticmethod(_mk_Tuple) + _construct_expr = staticmethod(sympify) + + +class FortranReturn(Token): + """ AST node explicitly mapped to a fortran "return". + + Explanation + =========== + + Because a return statement in fortran is different from C, and + in order to aid reuse of our codegen ASTs the ordinary + ``.codegen.ast.Return`` is interpreted as assignment to + the result variable of the function. If one for some reason needs + to generate a fortran RETURN statement, this node should be used. + + Examples + ======== + + >>> from sympy.codegen.fnodes import FortranReturn + >>> from sympy import fcode + >>> fcode(FortranReturn('x')) + ' return x' + + """ + __slots__ = _fields = ('return_value',) + defaults = {'return_value': none} + _construct_return_value = staticmethod(sympify) + + +class FFunction(Function): + _required_standard = 77 + + def _fcode(self, printer): + name = self.__class__.__name__ + if printer._settings['standard'] < self._required_standard: + raise NotImplementedError("%s requires Fortran %d or newer" % + (name, self._required_standard)) + return '{}({})'.format(name, ', '.join(map(printer._print, self.args))) + + +class F95Function(FFunction): + _required_standard = 95 + + +class isign(FFunction): + """ Fortran sign intrinsic for integer arguments. """ + nargs = 2 + + +class dsign(FFunction): + """ Fortran sign intrinsic for double precision arguments. """ + nargs = 2 + + +class cmplx(FFunction): + """ Fortran complex conversion function. """ + nargs = 2 # may be extended to (2, 3) at a later point + + +class kind(FFunction): + """ Fortran kind function. """ + nargs = 1 + + +class merge(F95Function): + """ Fortran merge function """ + nargs = 3 + + +class _literal(Float): + _token: str + _decimals: int + + def _fcode(self, printer, *args, **kwargs): + mantissa, sgnd_ex = ('%.{}e'.format(self._decimals) % self).split('e') + mantissa = mantissa.strip('0').rstrip('.') + ex_sgn, ex_num = sgnd_ex[0], sgnd_ex[1:].lstrip('0') + ex_sgn = '' if ex_sgn == '+' else ex_sgn + return (mantissa or '0') + self._token + ex_sgn + (ex_num or '0') + + +class literal_sp(_literal): + """ Fortran single precision real literal """ + _token = 'e' + _decimals = 9 + + +class literal_dp(_literal): + """ Fortran double precision real literal """ + _token = 'd' + _decimals = 17 + + +class sum_(Token, Expr): + __slots__ = _fields = ('array', 'dim', 'mask') + defaults = {'dim': none, 'mask': none} + _construct_array = staticmethod(sympify) + _construct_dim = staticmethod(sympify) + + +class product_(Token, Expr): + __slots__ = _fields = ('array', 'dim', 'mask') + defaults = {'dim': none, 'mask': none} + _construct_array = staticmethod(sympify) + _construct_dim = staticmethod(sympify) diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/futils.py b/.venv/lib/python3.13/site-packages/sympy/codegen/futils.py new file mode 100644 index 0000000000000000000000000000000000000000..4a1f5751fbd4d6b44d99c69a74ad89a8496f8648 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/futils.py @@ -0,0 +1,40 @@ +from itertools import chain +from sympy.codegen.fnodes import Module +from sympy.core.symbol import Dummy +from sympy.printing.fortran import FCodePrinter + +""" This module collects utilities for rendering Fortran code. """ + + +def render_as_module(definitions, name, declarations=(), printer_settings=None): + """ Creates a ``Module`` instance and renders it as a string. + + This generates Fortran source code for a module with the correct ``use`` statements. + + Parameters + ========== + + definitions : iterable + Passed to :class:`sympy.codegen.fnodes.Module`. + name : str + Passed to :class:`sympy.codegen.fnodes.Module`. + declarations : iterable + Passed to :class:`sympy.codegen.fnodes.Module`. It will be extended with + use statements, 'implicit none' and public list generated from ``definitions``. + printer_settings : dict + Passed to ``FCodePrinter`` (default: ``{'standard': 2003, 'source_format': 'free'}``). + + """ + printer_settings = printer_settings or {'standard': 2003, 'source_format': 'free'} + printer = FCodePrinter(printer_settings) + dummy = Dummy() + if isinstance(definitions, Module): + raise ValueError("This function expects to construct a module on its own.") + mod = Module(name, chain(declarations, [dummy]), definitions) + fstr = printer.doprint(mod) + module_use_str = ' %s\n' % ' \n'.join(['use %s, only: %s' % (k, ', '.join(v)) for + k, v in printer.module_uses.items()]) + module_use_str += ' implicit none\n' + module_use_str += ' private\n' + module_use_str += ' public %s\n' % ', '.join([str(node.name) for node in definitions if getattr(node, 'name', None)]) + return fstr.replace(printer.doprint(dummy), module_use_str) diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/matrix_nodes.py b/.venv/lib/python3.13/site-packages/sympy/codegen/matrix_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..cf0a13a81c963e2c9e2b885dabd0ff3e2d2b3eb9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/matrix_nodes.py @@ -0,0 +1,71 @@ +""" +Additional AST nodes for operations on matrices. The nodes in this module +are meant to represent optimization of matrix expressions within codegen's +target languages that cannot be represented by SymPy expressions. + +As an example, we can use :meth:`sympy.codegen.rewriting.optimize` and the +``matin_opt`` optimization provided in :mod:`sympy.codegen.rewriting` to +transform matrix multiplication under certain assumptions: + + >>> from sympy import symbols, MatrixSymbol + >>> n = symbols('n', integer=True) + >>> A = MatrixSymbol('A', n, n) + >>> x = MatrixSymbol('x', n, 1) + >>> expr = A**(-1) * x + >>> from sympy import assuming, Q + >>> from sympy.codegen.rewriting import matinv_opt, optimize + >>> with assuming(Q.fullrank(A)): + ... optimize(expr, [matinv_opt]) + MatrixSolve(A, vector=x) +""" + +from .ast import Token +from sympy.matrices import MatrixExpr +from sympy.core.sympify import sympify + + +class MatrixSolve(Token, MatrixExpr): + """Represents an operation to solve a linear matrix equation. + + Parameters + ========== + + matrix : MatrixSymbol + + Matrix representing the coefficients of variables in the linear + equation. This matrix must be square and full-rank (i.e. all columns must + be linearly independent) for the solving operation to be valid. + + vector : MatrixSymbol + + One-column matrix representing the solutions to the equations + represented in ``matrix``. + + Examples + ======== + + >>> from sympy import symbols, MatrixSymbol + >>> from sympy.codegen.matrix_nodes import MatrixSolve + >>> n = symbols('n', integer=True) + >>> A = MatrixSymbol('A', n, n) + >>> x = MatrixSymbol('x', n, 1) + >>> from sympy.printing.numpy import NumPyPrinter + >>> NumPyPrinter().doprint(MatrixSolve(A, x)) + 'numpy.linalg.solve(A, x)' + >>> from sympy import octave_code + >>> octave_code(MatrixSolve(A, x)) + 'A \\\\ x' + + """ + __slots__ = _fields = ('matrix', 'vector') + + _construct_matrix = staticmethod(sympify) + _construct_vector = staticmethod(sympify) + + @property + def shape(self): + return self.vector.shape + + def _eval_derivative(self, x): + A, b = self.matrix, self.vector + return MatrixSolve(A, b.diff(x) - A.diff(x) * MatrixSolve(A, b)) diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/numpy_nodes.py b/.venv/lib/python3.13/site-packages/sympy/codegen/numpy_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..c47c87f0e1df74cd314bb70674ebff732fe500aa --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/numpy_nodes.py @@ -0,0 +1,177 @@ +from sympy.core.function import Add, ArgumentIndexError, Function +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.core.sympify import sympify +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.miscellaneous import Max, Min +from .ast import Token, none + + +def _logaddexp(x1, x2, *, evaluate=True): + return log(Add(exp(x1, evaluate=evaluate), exp(x2, evaluate=evaluate), evaluate=evaluate)) + + +_two = S.One*2 +_ln2 = log(_two) + + +def _lb(x, *, evaluate=True): + return log(x, evaluate=evaluate)/_ln2 + + +def _exp2(x, *, evaluate=True): + return Pow(_two, x, evaluate=evaluate) + + +def _logaddexp2(x1, x2, *, evaluate=True): + return _lb(Add(_exp2(x1, evaluate=evaluate), + _exp2(x2, evaluate=evaluate), evaluate=evaluate)) + + +class logaddexp(Function): + """ Logarithm of the sum of exponentiations of the inputs. + + Helper class for use with e.g. numpy.logaddexp + + See Also + ======== + + https://numpy.org/doc/stable/reference/generated/numpy.logaddexp.html + """ + nargs = 2 + + def __new__(cls, *args): + return Function.__new__(cls, *sorted(args, key=default_sort_key)) + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function. + """ + if argindex == 1: + wrt, other = self.args + elif argindex == 2: + other, wrt = self.args + else: + raise ArgumentIndexError(self, argindex) + return S.One/(S.One + exp(other-wrt)) + + def _eval_rewrite_as_log(self, x1, x2, **kwargs): + return _logaddexp(x1, x2) + + def _eval_evalf(self, *args, **kwargs): + return self.rewrite(log).evalf(*args, **kwargs) + + def _eval_simplify(self, *args, **kwargs): + a, b = (x.simplify(**kwargs) for x in self.args) + candidate = _logaddexp(a, b) + if candidate != _logaddexp(a, b, evaluate=False): + return candidate + else: + return logaddexp(a, b) + + +class logaddexp2(Function): + """ Logarithm of the sum of exponentiations of the inputs in base-2. + + Helper class for use with e.g. numpy.logaddexp2 + + See Also + ======== + + https://numpy.org/doc/stable/reference/generated/numpy.logaddexp2.html + """ + nargs = 2 + + def __new__(cls, *args): + return Function.__new__(cls, *sorted(args, key=default_sort_key)) + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function. + """ + if argindex == 1: + wrt, other = self.args + elif argindex == 2: + other, wrt = self.args + else: + raise ArgumentIndexError(self, argindex) + return S.One/(S.One + _exp2(other-wrt)) + + def _eval_rewrite_as_log(self, x1, x2, **kwargs): + return _logaddexp2(x1, x2) + + def _eval_evalf(self, *args, **kwargs): + return self.rewrite(log).evalf(*args, **kwargs) + + def _eval_simplify(self, *args, **kwargs): + a, b = (x.simplify(**kwargs).factor() for x in self.args) + candidate = _logaddexp2(a, b) + if candidate != _logaddexp2(a, b, evaluate=False): + return candidate + else: + return logaddexp2(a, b) + + +class amin(Token): + """ Minimum value along an axis. + + Helper class for use with e.g. numpy.amin + + + See Also + ======== + + https://numpy.org/doc/stable/reference/generated/numpy.amin.html + """ + __slots__ = _fields = ('array', 'axis') + defaults = {'axis': none} + _construct_axis = staticmethod(sympify) + + +class amax(Token): + """ Maximum value along an axis. + + Helper class for use with e.g. numpy.amax + + + See Also + ======== + + https://numpy.org/doc/stable/reference/generated/numpy.amax.html + """ + __slots__ = _fields = ('array', 'axis') + defaults = {'axis': none} + _construct_axis = staticmethod(sympify) + + +class maximum(Function): + """ Element-wise maximum of array elements. + + Helper class for use with e.g. numpy.maximum + + + See Also + ======== + + https://numpy.org/doc/stable/reference/generated/numpy.maximum.html + """ + + def _eval_rewrite_as_Max(self, *args): + return Max(*self.args) + + +class minimum(Function): + """ Element-wise minimum of array elements. + + Helper class for use with e.g. numpy.minimum + + + See Also + ======== + + https://numpy.org/doc/stable/reference/generated/numpy.minimum.html + """ + + def _eval_rewrite_as_Min(self, *args): + return Min(*self.args) diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/pynodes.py b/.venv/lib/python3.13/site-packages/sympy/codegen/pynodes.py new file mode 100644 index 0000000000000000000000000000000000000000..f0a08b4a79d32f63d345947d6be310b44504dbf5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/pynodes.py @@ -0,0 +1,11 @@ +from .abstract_nodes import List as AbstractList +from .ast import Token + + +class List(AbstractList): + pass + + +class NumExprEvaluate(Token): + """represents a call to :class:`numexpr`s :func:`evaluate`""" + __slots__ = _fields = ('expr',) diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/pyutils.py b/.venv/lib/python3.13/site-packages/sympy/codegen/pyutils.py new file mode 100644 index 0000000000000000000000000000000000000000..e14eabe92ce50105a4055b71a49767aae04610b9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/pyutils.py @@ -0,0 +1,24 @@ +from sympy.printing.pycode import PythonCodePrinter + +""" This module collects utilities for rendering Python code. """ + + +def render_as_module(content, standard='python3'): + """Renders Python code as a module (with the required imports). + + Parameters + ========== + + standard : + See the parameter ``standard`` in + :meth:`sympy.printing.pycode.pycode` + """ + + printer = PythonCodePrinter({'standard':standard}) + pystr = printer.doprint(content) + if printer._settings['fully_qualified_modules']: + module_imports_str = '\n'.join('import %s' % k for k in printer.module_imports) + else: + module_imports_str = '\n'.join(['from %s import %s' % (k, ', '.join(v)) for + k, v in printer.module_imports.items()]) + return module_imports_str + '\n\n' + pystr diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/rewriting.py b/.venv/lib/python3.13/site-packages/sympy/codegen/rewriting.py new file mode 100644 index 0000000000000000000000000000000000000000..274b7770b46ded6711468ab2a01db3a53d6fde87 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/rewriting.py @@ -0,0 +1,357 @@ +""" +Classes and functions useful for rewriting expressions for optimized code +generation. Some languages (or standards thereof), e.g. C99, offer specialized +math functions for better performance and/or precision. + +Using the ``optimize`` function in this module, together with a collection of +rules (represented as instances of ``Optimization``), one can rewrite the +expressions for this purpose:: + + >>> from sympy import Symbol, exp, log + >>> from sympy.codegen.rewriting import optimize, optims_c99 + >>> x = Symbol('x') + >>> optimize(3*exp(2*x) - 3, optims_c99) + 3*expm1(2*x) + >>> optimize(exp(2*x) - 1 - exp(-33), optims_c99) + expm1(2*x) - exp(-33) + >>> optimize(log(3*x + 3), optims_c99) + log1p(x) + log(3) + >>> optimize(log(2*x + 3), optims_c99) + log(2*x + 3) + +The ``optims_c99`` imported above is tuple containing the following instances +(which may be imported from ``sympy.codegen.rewriting``): + +- ``expm1_opt`` +- ``log1p_opt`` +- ``exp2_opt`` +- ``log2_opt`` +- ``log2const_opt`` + + +""" +from sympy.core.function import expand_log +from sympy.core.singleton import S +from sympy.core.symbol import Wild +from sympy.functions.elementary.complexes import sign +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (Max, Min) +from sympy.functions.elementary.trigonometric import (cos, sin, sinc) +from sympy.assumptions import Q, ask +from sympy.codegen.cfunctions import log1p, log2, exp2, expm1 +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.core.expr import UnevaluatedExpr +from sympy.core.power import Pow +from sympy.codegen.numpy_nodes import logaddexp, logaddexp2 +from sympy.codegen.scipy_nodes import cosm1, powm1 +from sympy.core.mul import Mul +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.utilities.iterables import sift + + +class Optimization: + """ Abstract base class for rewriting optimization. + + Subclasses should implement ``__call__`` taking an expression + as argument. + + Parameters + ========== + cost_function : callable returning number + priority : number + + """ + def __init__(self, cost_function=None, priority=1): + self.cost_function = cost_function + self.priority=priority + + def cheapest(self, *args): + return min(args, key=self.cost_function) + + +class ReplaceOptim(Optimization): + """ Rewriting optimization calling replace on expressions. + + Explanation + =========== + + The instance can be used as a function on expressions for which + it will apply the ``replace`` method (see + :meth:`sympy.core.basic.Basic.replace`). + + Parameters + ========== + + query : + First argument passed to replace. + value : + Second argument passed to replace. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.codegen.rewriting import ReplaceOptim + >>> from sympy.codegen.cfunctions import exp2 + >>> x = Symbol('x') + >>> exp2_opt = ReplaceOptim(lambda p: p.is_Pow and p.base == 2, + ... lambda p: exp2(p.exp)) + >>> exp2_opt(2**x) + exp2(x) + + """ + + def __init__(self, query, value, **kwargs): + super().__init__(**kwargs) + self.query = query + self.value = value + + def __call__(self, expr): + return expr.replace(self.query, self.value) + + +def optimize(expr, optimizations): + """ Apply optimizations to an expression. + + Parameters + ========== + + expr : expression + optimizations : iterable of ``Optimization`` instances + The optimizations will be sorted with respect to ``priority`` (highest first). + + Examples + ======== + + >>> from sympy import log, Symbol + >>> from sympy.codegen.rewriting import optims_c99, optimize + >>> x = Symbol('x') + >>> optimize(log(x+3)/log(2) + log(x**2 + 1), optims_c99) + log1p(x**2) + log2(x + 3) + + """ + + for optim in sorted(optimizations, key=lambda opt: opt.priority, reverse=True): + new_expr = optim(expr) + if optim.cost_function is None: + expr = new_expr + else: + expr = optim.cheapest(expr, new_expr) + return expr + + +exp2_opt = ReplaceOptim( + lambda p: p.is_Pow and p.base == 2, + lambda p: exp2(p.exp) +) + + +_d = Wild('d', properties=[lambda x: x.is_Dummy]) +_u = Wild('u', properties=[lambda x: not x.is_number and not x.is_Add]) +_v = Wild('v') +_w = Wild('w') +_n = Wild('n', properties=[lambda x: x.is_number]) + +sinc_opt1 = ReplaceOptim( + sin(_w)/_w, sinc(_w) +) +sinc_opt2 = ReplaceOptim( + sin(_n*_w)/_w, _n*sinc(_n*_w) +) +sinc_opts = (sinc_opt1, sinc_opt2) + +log2_opt = ReplaceOptim(_v*log(_w)/log(2), _v*log2(_w), cost_function=lambda expr: expr.count( + lambda e: ( # division & eval of transcendentals are expensive floating point operations... + e.is_Pow and e.exp.is_negative # division + or (isinstance(e, (log, log2)) and not e.args[0].is_number)) # transcendental + ) +) + +log2const_opt = ReplaceOptim(log(2)*log2(_w), log(_w)) + +logsumexp_2terms_opt = ReplaceOptim( + lambda l: (isinstance(l, log) + and l.args[0].is_Add + and len(l.args[0].args) == 2 + and all(isinstance(t, exp) for t in l.args[0].args)), + lambda l: ( + Max(*[e.args[0] for e in l.args[0].args]) + + log1p(exp(Min(*[e.args[0] for e in l.args[0].args]))) + ) +) + + +class FuncMinusOneOptim(ReplaceOptim): + """Specialization of ReplaceOptim for functions evaluating "f(x) - 1". + + Explanation + =========== + + Numerical functions which go toward one as x go toward zero is often best + implemented by a dedicated function in order to avoid catastrophic + cancellation. One such example is ``expm1(x)`` in the C standard library + which evaluates ``exp(x) - 1``. Such functions preserves many more + significant digits when its argument is much smaller than one, compared + to subtracting one afterwards. + + Parameters + ========== + + func : + The function which is subtracted by one. + func_m_1 : + The specialized function evaluating ``func(x) - 1``. + opportunistic : bool + When ``True``, apply the transformation as long as the magnitude of the + remaining number terms decreases. When ``False``, only apply the + transformation if it completely eliminates the number term. + + Examples + ======== + + >>> from sympy import symbols, exp + >>> from sympy.codegen.rewriting import FuncMinusOneOptim + >>> from sympy.codegen.cfunctions import expm1 + >>> x, y = symbols('x y') + >>> expm1_opt = FuncMinusOneOptim(exp, expm1) + >>> expm1_opt(exp(x) + 2*exp(5*y) - 3) + expm1(x) + 2*expm1(5*y) + + + """ + + def __init__(self, func, func_m_1, opportunistic=True): + weight = 10 # <-- this is an arbitrary number (heuristic) + super().__init__(lambda e: e.is_Add, self.replace_in_Add, + cost_function=lambda expr: expr.count_ops() - weight*expr.count(func_m_1)) + self.func = func + self.func_m_1 = func_m_1 + self.opportunistic = opportunistic + + def _group_Add_terms(self, add): + numbers, non_num = sift(add.args, lambda arg: arg.is_number, binary=True) + numsum = sum(numbers) + terms_with_func, other = sift(non_num, lambda arg: arg.has(self.func), binary=True) + return numsum, terms_with_func, other + + def replace_in_Add(self, e): + """ passed as second argument to Basic.replace(...) """ + numsum, terms_with_func, other_non_num_terms = self._group_Add_terms(e) + if numsum == 0: + return e + substituted, untouched = [], [] + for with_func in terms_with_func: + if with_func.is_Mul: + func, coeff = sift(with_func.args, lambda arg: arg.func == self.func, binary=True) + if len(func) == 1 and len(coeff) == 1: + func, coeff = func[0], coeff[0] + else: + coeff = None + elif with_func.func == self.func: + func, coeff = with_func, S.One + else: + coeff = None + + if coeff is not None and coeff.is_number and sign(coeff) == -sign(numsum): + if self.opportunistic: + do_substitute = abs(coeff+numsum) < abs(numsum) + else: + do_substitute = coeff+numsum == 0 + + if do_substitute: # advantageous substitution + numsum += coeff + substituted.append(coeff*self.func_m_1(*func.args)) + continue + untouched.append(with_func) + + return e.func(numsum, *substituted, *untouched, *other_non_num_terms) + + def __call__(self, expr): + alt1 = super().__call__(expr) + alt2 = super().__call__(expr.factor()) + return self.cheapest(alt1, alt2) + + +expm1_opt = FuncMinusOneOptim(exp, expm1) +cosm1_opt = FuncMinusOneOptim(cos, cosm1) +powm1_opt = FuncMinusOneOptim(Pow, powm1) + +log1p_opt = ReplaceOptim( + lambda e: isinstance(e, log), + lambda l: expand_log(l.replace( + log, lambda arg: log(arg.factor()) + )).replace(log(_u+1), log1p(_u)) +) + +def create_expand_pow_optimization(limit, *, base_req=lambda b: b.is_symbol): + """ Creates an instance of :class:`ReplaceOptim` for expanding ``Pow``. + + Explanation + =========== + + The requirements for expansions are that the base needs to be a symbol + and the exponent needs to be an Integer (and be less than or equal to + ``limit``). + + Parameters + ========== + + limit : int + The highest power which is expanded into multiplication. + base_req : function returning bool + Requirement on base for expansion to happen, default is to return + the ``is_symbol`` attribute of the base. + + Examples + ======== + + >>> from sympy import Symbol, sin + >>> from sympy.codegen.rewriting import create_expand_pow_optimization + >>> x = Symbol('x') + >>> expand_opt = create_expand_pow_optimization(3) + >>> expand_opt(x**5 + x**3) + x**5 + x*x*x + >>> expand_opt(x**5 + x**3 + sin(x)**3) + x**5 + sin(x)**3 + x*x*x + >>> opt2 = create_expand_pow_optimization(3, base_req=lambda b: not b.is_Function) + >>> opt2((x+1)**2 + sin(x)**2) + sin(x)**2 + (x + 1)*(x + 1) + + """ + return ReplaceOptim( + lambda e: e.is_Pow and base_req(e.base) and e.exp.is_Integer and abs(e.exp) <= limit, + lambda p: ( + UnevaluatedExpr(Mul(*([p.base]*+p.exp), evaluate=False)) if p.exp > 0 else + 1/UnevaluatedExpr(Mul(*([p.base]*-p.exp), evaluate=False)) + )) + +# Optimization procedures for turning A**(-1) * x into MatrixSolve(A, x) +def _matinv_predicate(expr): + # TODO: We should be able to support more than 2 elements + if expr.is_MatMul and len(expr.args) == 2: + left, right = expr.args + if left.is_Inverse and right.shape[1] == 1: + inv_arg = left.arg + if isinstance(inv_arg, MatrixSymbol): + return bool(ask(Q.fullrank(left.arg))) + + return False + +def _matinv_transform(expr): + left, right = expr.args + inv_arg = left.arg + return MatrixSolve(inv_arg, right) + + +matinv_opt = ReplaceOptim(_matinv_predicate, _matinv_transform) + + +logaddexp_opt = ReplaceOptim(log(exp(_v)+exp(_w)), logaddexp(_v, _w)) +logaddexp2_opt = ReplaceOptim(log(Pow(2, _v)+Pow(2, _w)), logaddexp2(_v, _w)*log(2)) + +# Collections of optimizations: +optims_c99 = (expm1_opt, log1p_opt, exp2_opt, log2_opt, log2const_opt) + +optims_numpy = optims_c99 + (logaddexp_opt, logaddexp2_opt,) + sinc_opts + +optims_scipy = (cosm1_opt, powm1_opt) diff --git a/.venv/lib/python3.13/site-packages/sympy/codegen/scipy_nodes.py b/.venv/lib/python3.13/site-packages/sympy/codegen/scipy_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..059a853fc8cac6e0b9a1a3c7395dd3a15384dcba --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/codegen/scipy_nodes.py @@ -0,0 +1,79 @@ +from sympy.core.function import Add, ArgumentIndexError, Function +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.trigonometric import cos, sin + + +def _cosm1(x, *, evaluate=True): + return Add(cos(x, evaluate=evaluate), -S.One, evaluate=evaluate) + + +class cosm1(Function): + """ Minus one plus cosine of x, i.e. cos(x) - 1. For use when x is close to zero. + + Helper class for use with e.g. scipy.special.cosm1 + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.cosm1.html + """ + nargs = 1 + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function. + """ + if argindex == 1: + return -sin(*self.args) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_cos(self, x, **kwargs): + return _cosm1(x) + + def _eval_evalf(self, *args, **kwargs): + return self.rewrite(cos).evalf(*args, **kwargs) + + def _eval_simplify(self, **kwargs): + x, = self.args + candidate = _cosm1(x.simplify(**kwargs)) + if candidate != _cosm1(x, evaluate=False): + return candidate + else: + return cosm1(x) + + +def _powm1(x, y, *, evaluate=True): + return Add(Pow(x, y, evaluate=evaluate), -S.One, evaluate=evaluate) + + +class powm1(Function): + """ Minus one plus x to the power of y, i.e. x**y - 1. For use when x is close to one or y is close to zero. + + Helper class for use with e.g. scipy.special.powm1 + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.powm1.html + """ + nargs = 2 + + def fdiff(self, argindex=1): + """ + Returns the first derivative of this function. + """ + if argindex == 1: + return Pow(self.args[0], self.args[1])*self.args[1]/self.args[0] + elif argindex == 2: + return log(self.args[0])*Pow(*self.args) + else: + raise ArgumentIndexError(self, argindex) + + def _eval_rewrite_as_Pow(self, x, y, **kwargs): + return _powm1(x, y) + + def _eval_evalf(self, *args, **kwargs): + return self.rewrite(Pow).evalf(*args, **kwargs) + + def _eval_simplify(self, **kwargs): + x, y = self.args + candidate = _powm1(x.simplify(**kwargs), y.simplify(**kwargs)) + if candidate != _powm1(x, y, evaluate=False): + return candidate + else: + return powm1(x, y) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/__init__.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..220c27a42975555f8fe7770fcef0dd18475c89e1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/__init__.py @@ -0,0 +1,43 @@ +from sympy.combinatorics.permutations import Permutation, Cycle +from sympy.combinatorics.prufer import Prufer +from sympy.combinatorics.generators import cyclic, alternating, symmetric, dihedral +from sympy.combinatorics.subsets import Subset +from sympy.combinatorics.partitions import (Partition, IntegerPartition, + RGS_rank, RGS_unrank, RGS_enum) +from sympy.combinatorics.polyhedron import (Polyhedron, tetrahedron, cube, + octahedron, dodecahedron, icosahedron) +from sympy.combinatorics.perm_groups import PermutationGroup, Coset, SymmetricPermutationGroup +from sympy.combinatorics.group_constructs import DirectProduct +from sympy.combinatorics.graycode import GrayCode +from sympy.combinatorics.named_groups import (SymmetricGroup, DihedralGroup, + CyclicGroup, AlternatingGroup, AbelianGroup, RubikGroup) +from sympy.combinatorics.pc_groups import PolycyclicGroup, Collector +from sympy.combinatorics.free_groups import free_group + +__all__ = [ + 'Permutation', 'Cycle', + + 'Prufer', + + 'cyclic', 'alternating', 'symmetric', 'dihedral', + + 'Subset', + + 'Partition', 'IntegerPartition', 'RGS_rank', 'RGS_unrank', 'RGS_enum', + + 'Polyhedron', 'tetrahedron', 'cube', 'octahedron', 'dodecahedron', + 'icosahedron', + + 'PermutationGroup', 'Coset', 'SymmetricPermutationGroup', + + 'DirectProduct', + + 'GrayCode', + + 'SymmetricGroup', 'DihedralGroup', 'CyclicGroup', 'AlternatingGroup', + 'AbelianGroup', 'RubikGroup', + + 'PolycyclicGroup', 'Collector', + + 'free_group', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/coset_table.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/coset_table.py new file mode 100644 index 0000000000000000000000000000000000000000..0687aa34ec70e9062f65ff6413213848cf03e51b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/coset_table.py @@ -0,0 +1,1259 @@ +from sympy.combinatorics.free_groups import free_group +from sympy.printing.defaults import DefaultPrinting + +from itertools import chain, product +from bisect import bisect_left + + +############################################################################### +# COSET TABLE # +############################################################################### + +class CosetTable(DefaultPrinting): + # coset_table: Mathematically a coset table + # represented using a list of lists + # alpha: Mathematically a coset (precisely, a live coset) + # represented by an integer between i with 1 <= i <= n + # alpha in c + # x: Mathematically an element of "A" (set of generators and + # their inverses), represented using "FpGroupElement" + # fp_grp: Finitely Presented Group with < X|R > as presentation. + # H: subgroup of fp_grp. + # NOTE: We start with H as being only a list of words in generators + # of "fp_grp". Since `.subgroup` method has not been implemented. + + r""" + + Properties + ========== + + [1] `0 \in \Omega` and `\tau(1) = \epsilon` + [2] `\alpha^x = \beta \Leftrightarrow \beta^{x^{-1}} = \alpha` + [3] If `\alpha^x = \beta`, then `H \tau(\alpha)x = H \tau(\beta)` + [4] `\forall \alpha \in \Omega, 1^{\tau(\alpha)} = \alpha` + + References + ========== + + .. [1] Holt, D., Eick, B., O'Brien, E. + "Handbook of Computational Group Theory" + + .. [2] John J. Cannon; Lucien A. Dimino; George Havas; Jane M. Watson + Mathematics of Computation, Vol. 27, No. 123. (Jul., 1973), pp. 463-490. + "Implementation and Analysis of the Todd-Coxeter Algorithm" + + """ + # default limit for the number of cosets allowed in a + # coset enumeration. + coset_table_max_limit = 4096000 + # limit for the current instance + coset_table_limit = None + # maximum size of deduction stack above or equal to + # which it is emptied + max_stack_size = 100 + + def __init__(self, fp_grp, subgroup, max_cosets=None): + if not max_cosets: + max_cosets = CosetTable.coset_table_max_limit + self.fp_group = fp_grp + self.subgroup = subgroup + self.coset_table_limit = max_cosets + # "p" is setup independent of Omega and n + self.p = [0] + # a list of the form `[gen_1, gen_1^{-1}, ... , gen_k, gen_k^{-1}]` + self.A = list(chain.from_iterable((gen, gen**-1) \ + for gen in self.fp_group.generators)) + #P[alpha, x] Only defined when alpha^x is defined. + self.P = [[None]*len(self.A)] + # the mathematical coset table which is a list of lists + self.table = [[None]*len(self.A)] + self.A_dict = {x: self.A.index(x) for x in self.A} + self.A_dict_inv = {} + for x, index in self.A_dict.items(): + if index % 2 == 0: + self.A_dict_inv[x] = self.A_dict[x] + 1 + else: + self.A_dict_inv[x] = self.A_dict[x] - 1 + # used in the coset-table based method of coset enumeration. Each of + # the element is called a "deduction" which is the form (alpha, x) whenever + # a value is assigned to alpha^x during a definition or "deduction process" + self.deduction_stack = [] + # Attributes for modified methods. + H = self.subgroup + self._grp = free_group(', ' .join(["a_%d" % i for i in range(len(H))]))[0] + self.P = [[None]*len(self.A)] + self.p_p = {} + + @property + def omega(self): + """Set of live cosets. """ + return [coset for coset in range(len(self.p)) if self.p[coset] == coset] + + def copy(self): + """ + Return a shallow copy of Coset Table instance ``self``. + + """ + self_copy = self.__class__(self.fp_group, self.subgroup) + self_copy.table = [list(perm_rep) for perm_rep in self.table] + self_copy.p = list(self.p) + self_copy.deduction_stack = list(self.deduction_stack) + return self_copy + + def __str__(self): + return "Coset Table on %s with %s as subgroup generators" \ + % (self.fp_group, self.subgroup) + + __repr__ = __str__ + + @property + def n(self): + """The number `n` represents the length of the sublist containing the + live cosets. + + """ + if not self.table: + return 0 + return max(self.omega) + 1 + + # Pg. 152 [1] + def is_complete(self): + r""" + The coset table is called complete if it has no undefined entries + on the live cosets; that is, `\alpha^x` is defined for all + `\alpha \in \Omega` and `x \in A`. + + """ + return not any(None in self.table[coset] for coset in self.omega) + + # Pg. 153 [1] + def define(self, alpha, x, modified=False): + r""" + This routine is used in the relator-based strategy of Todd-Coxeter + algorithm if some `\alpha^x` is undefined. We check whether there is + space available for defining a new coset. If there is enough space + then we remedy this by adjoining a new coset `\beta` to `\Omega` + (i.e to set of live cosets) and put that equal to `\alpha^x`, then + make an assignment satisfying Property[1]. If there is not enough space + then we halt the Coset Table creation. The maximum amount of space that + can be used by Coset Table can be manipulated using the class variable + ``CosetTable.coset_table_max_limit``. + + See Also + ======== + + define_c + + """ + A = self.A + table = self.table + len_table = len(table) + if len_table >= self.coset_table_limit: + # abort the further generation of cosets + raise ValueError("the coset enumeration has defined more than " + "%s cosets. Try with a greater value max number of cosets " + % self.coset_table_limit) + table.append([None]*len(A)) + self.P.append([None]*len(self.A)) + # beta is the new coset generated + beta = len_table + self.p.append(beta) + table[alpha][self.A_dict[x]] = beta + table[beta][self.A_dict_inv[x]] = alpha + # P[alpha][x] = epsilon, P[beta][x**-1] = epsilon + if modified: + self.P[alpha][self.A_dict[x]] = self._grp.identity + self.P[beta][self.A_dict_inv[x]] = self._grp.identity + self.p_p[beta] = self._grp.identity + + def define_c(self, alpha, x): + r""" + A variation of ``define`` routine, described on Pg. 165 [1], used in + the coset table-based strategy of Todd-Coxeter algorithm. It differs + from ``define`` routine in that for each definition it also adds the + tuple `(\alpha, x)` to the deduction stack. + + See Also + ======== + + define + + """ + A = self.A + table = self.table + len_table = len(table) + if len_table >= self.coset_table_limit: + # abort the further generation of cosets + raise ValueError("the coset enumeration has defined more than " + "%s cosets. Try with a greater value max number of cosets " + % self.coset_table_limit) + table.append([None]*len(A)) + # beta is the new coset generated + beta = len_table + self.p.append(beta) + table[alpha][self.A_dict[x]] = beta + table[beta][self.A_dict_inv[x]] = alpha + # append to deduction stack + self.deduction_stack.append((alpha, x)) + + def scan_c(self, alpha, word): + """ + A variation of ``scan`` routine, described on pg. 165 of [1], which + puts at tuple, whenever a deduction occurs, to deduction stack. + + See Also + ======== + + scan, scan_check, scan_and_fill, scan_and_fill_c + + """ + # alpha is an integer representing a "coset" + # since scanning can be in two cases + # 1. for alpha=0 and w in Y (i.e generating set of H) + # 2. alpha in Omega (set of live cosets), w in R (relators) + A_dict = self.A_dict + A_dict_inv = self.A_dict_inv + table = self.table + f = alpha + i = 0 + r = len(word) + b = alpha + j = r - 1 + # list of union of generators and their inverses + while i <= j and table[f][A_dict[word[i]]] is not None: + f = table[f][A_dict[word[i]]] + i += 1 + if i > j: + if f != b: + self.coincidence_c(f, b) + return + while j >= i and table[b][A_dict_inv[word[j]]] is not None: + b = table[b][A_dict_inv[word[j]]] + j -= 1 + if j < i: + # we have an incorrect completed scan with coincidence f ~ b + # run the "coincidence" routine + self.coincidence_c(f, b) + elif j == i: + # deduction process + table[f][A_dict[word[i]]] = b + table[b][A_dict_inv[word[i]]] = f + self.deduction_stack.append((f, word[i])) + # otherwise scan is incomplete and yields no information + + # alpha, beta coincide, i.e. alpha, beta represent the pair of cosets where + # coincidence occurs + def coincidence_c(self, alpha, beta): + """ + A variation of ``coincidence`` routine used in the coset-table based + method of coset enumeration. The only difference being on addition of + a new coset in coset table(i.e new coset introduction), then it is + appended to ``deduction_stack``. + + See Also + ======== + + coincidence + + """ + A_dict = self.A_dict + A_dict_inv = self.A_dict_inv + table = self.table + # behaves as a queue + q = [] + self.merge(alpha, beta, q) + while len(q) > 0: + gamma = q.pop(0) + for x in A_dict: + delta = table[gamma][A_dict[x]] + if delta is not None: + table[delta][A_dict_inv[x]] = None + # only line of difference from ``coincidence`` routine + self.deduction_stack.append((delta, x**-1)) + mu = self.rep(gamma) + nu = self.rep(delta) + if table[mu][A_dict[x]] is not None: + self.merge(nu, table[mu][A_dict[x]], q) + elif table[nu][A_dict_inv[x]] is not None: + self.merge(mu, table[nu][A_dict_inv[x]], q) + else: + table[mu][A_dict[x]] = nu + table[nu][A_dict_inv[x]] = mu + + def scan(self, alpha, word, y=None, fill=False, modified=False): + r""" + ``scan`` performs a scanning process on the input ``word``. + It first locates the largest prefix ``s`` of ``word`` for which + `\alpha^s` is defined (i.e is not ``None``), ``s`` may be empty. Let + ``word=sv``, let ``t`` be the longest suffix of ``v`` for which + `\alpha^{t^{-1}}` is defined, and let ``v=ut``. Then three + possibilities are there: + + 1. If ``t=v``, then we say that the scan completes, and if, in addition + `\alpha^s = \alpha^{t^{-1}}`, then we say that the scan completes + correctly. + + 2. It can also happen that scan does not complete, but `|u|=1`; that + is, the word ``u`` consists of a single generator `x \in A`. In that + case, if `\alpha^s = \beta` and `\alpha^{t^{-1}} = \gamma`, then we can + set `\beta^x = \gamma` and `\gamma^{x^{-1}} = \beta`. These assignments + are known as deductions and enable the scan to complete correctly. + + 3. See ``coicidence`` routine for explanation of third condition. + + Notes + ===== + + The code for the procedure of scanning `\alpha \in \Omega` + under `w \in A*` is defined on pg. 155 [1] + + See Also + ======== + + scan_c, scan_check, scan_and_fill, scan_and_fill_c + + Scan and Fill + ============= + + Performed when the default argument fill=True. + + Modified Scan + ============= + + Performed when the default argument modified=True + + """ + # alpha is an integer representing a "coset" + # since scanning can be in two cases + # 1. for alpha=0 and w in Y (i.e generating set of H) + # 2. alpha in Omega (set of live cosets), w in R (relators) + A_dict = self.A_dict + A_dict_inv = self.A_dict_inv + table = self.table + f = alpha + i = 0 + r = len(word) + b = alpha + j = r - 1 + b_p = y + if modified: + f_p = self._grp.identity + flag = 0 + while fill or flag == 0: + flag = 1 + while i <= j and table[f][A_dict[word[i]]] is not None: + if modified: + f_p = f_p*self.P[f][A_dict[word[i]]] + f = table[f][A_dict[word[i]]] + i += 1 + if i > j: + if f != b: + if modified: + self.modified_coincidence(f, b, f_p**-1*y) + else: + self.coincidence(f, b) + return + while j >= i and table[b][A_dict_inv[word[j]]] is not None: + if modified: + b_p = b_p*self.P[b][self.A_dict_inv[word[j]]] + b = table[b][A_dict_inv[word[j]]] + j -= 1 + if j < i: + # we have an incorrect completed scan with coincidence f ~ b + # run the "coincidence" routine + if modified: + self.modified_coincidence(f, b, f_p**-1*b_p) + else: + self.coincidence(f, b) + elif j == i: + # deduction process + table[f][A_dict[word[i]]] = b + table[b][A_dict_inv[word[i]]] = f + if modified: + self.P[f][self.A_dict[word[i]]] = f_p**-1*b_p + self.P[b][self.A_dict_inv[word[i]]] = b_p**-1*f_p + return + elif fill: + self.define(f, word[i], modified=modified) + # otherwise scan is incomplete and yields no information + + # used in the low-index subgroups algorithm + def scan_check(self, alpha, word): + r""" + Another version of ``scan`` routine, described on, it checks whether + `\alpha` scans correctly under `word`, it is a straightforward + modification of ``scan``. ``scan_check`` returns ``False`` (rather than + calling ``coincidence``) if the scan completes incorrectly; otherwise + it returns ``True``. + + See Also + ======== + + scan, scan_c, scan_and_fill, scan_and_fill_c + + """ + # alpha is an integer representing a "coset" + # since scanning can be in two cases + # 1. for alpha=0 and w in Y (i.e generating set of H) + # 2. alpha in Omega (set of live cosets), w in R (relators) + A_dict = self.A_dict + A_dict_inv = self.A_dict_inv + table = self.table + f = alpha + i = 0 + r = len(word) + b = alpha + j = r - 1 + while i <= j and table[f][A_dict[word[i]]] is not None: + f = table[f][A_dict[word[i]]] + i += 1 + if i > j: + return f == b + while j >= i and table[b][A_dict_inv[word[j]]] is not None: + b = table[b][A_dict_inv[word[j]]] + j -= 1 + if j < i: + # we have an incorrect completed scan with coincidence f ~ b + # return False, instead of calling coincidence routine + return False + elif j == i: + # deduction process + table[f][A_dict[word[i]]] = b + table[b][A_dict_inv[word[i]]] = f + return True + + def merge(self, k, lamda, q, w=None, modified=False): + """ + Merge two classes with representatives ``k`` and ``lamda``, described + on Pg. 157 [1] (for pseudocode), start by putting ``p[k] = lamda``. + It is more efficient to choose the new representative from the larger + of the two classes being merged, i.e larger among ``k`` and ``lamda``. + procedure ``merge`` performs the merging operation, adds the deleted + class representative to the queue ``q``. + + Parameters + ========== + + 'k', 'lamda' being the two class representatives to be merged. + + Notes + ===== + + Pg. 86-87 [1] contains a description of this method. + + See Also + ======== + + coincidence, rep + + """ + p = self.p + rep = self.rep + phi = rep(k, modified=modified) + psi = rep(lamda, modified=modified) + if phi != psi: + mu = min(phi, psi) + v = max(phi, psi) + p[v] = mu + if modified: + if v == phi: + self.p_p[phi] = self.p_p[k]**-1*w*self.p_p[lamda] + else: + self.p_p[psi] = self.p_p[lamda]**-1*w**-1*self.p_p[k] + q.append(v) + + def rep(self, k, modified=False): + r""" + Parameters + ========== + + `k \in [0 \ldots n-1]`, as for ``self`` only array ``p`` is used + + Returns + ======= + + Representative of the class containing ``k``. + + Returns the representative of `\sim` class containing ``k``, it also + makes some modification to array ``p`` of ``self`` to ease further + computations, described on Pg. 157 [1]. + + The information on classes under `\sim` is stored in array `p` of + ``self`` argument, which will always satisfy the property: + + `p[\alpha] \sim \alpha` and `p[\alpha]=\alpha \iff \alpha=rep(\alpha)` + `\forall \in [0 \ldots n-1]`. + + So, for `\alpha \in [0 \ldots n-1]`, we find `rep(self, \alpha)` by + continually replacing `\alpha` by `p[\alpha]` until it becomes + constant (i.e satisfies `p[\alpha] = \alpha`):w + + To increase the efficiency of later ``rep`` calculations, whenever we + find `rep(self, \alpha)=\beta`, we set + `p[\gamma] = \beta \forall \gamma \in p-chain` from `\alpha` to `\beta` + + Notes + ===== + + ``rep`` routine is also described on Pg. 85-87 [1] in Atkinson's + algorithm, this results from the fact that ``coincidence`` routine + introduces functionality similar to that introduced by the + ``minimal_block`` routine on Pg. 85-87 [1]. + + See Also + ======== + + coincidence, merge + + """ + p = self.p + lamda = k + rho = p[lamda] + if modified: + s = p[:] + while rho != lamda: + if modified: + s[rho] = lamda + lamda = rho + rho = p[lamda] + if modified: + rho = s[lamda] + while rho != k: + mu = rho + rho = s[mu] + p[rho] = lamda + self.p_p[rho] = self.p_p[rho]*self.p_p[mu] + else: + mu = k + rho = p[mu] + while rho != lamda: + p[mu] = lamda + mu = rho + rho = p[mu] + return lamda + + # alpha, beta coincide, i.e. alpha, beta represent the pair of cosets + # where coincidence occurs + def coincidence(self, alpha, beta, w=None, modified=False): + r""" + The third situation described in ``scan`` routine is handled by this + routine, described on Pg. 156-161 [1]. + + The unfortunate situation when the scan completes but not correctly, + then ``coincidence`` routine is run. i.e when for some `i` with + `1 \le i \le r+1`, we have `w=st` with `s = x_1 x_2 \dots x_{i-1}`, + `t = x_i x_{i+1} \dots x_r`, and `\beta = \alpha^s` and + `\gamma = \alpha^{t-1}` are defined but unequal. This means that + `\beta` and `\gamma` represent the same coset of `H` in `G`. Described + on Pg. 156 [1]. ``rep`` + + See Also + ======== + + scan + + """ + A_dict = self.A_dict + A_dict_inv = self.A_dict_inv + table = self.table + # behaves as a queue + q = [] + if modified: + self.modified_merge(alpha, beta, w, q) + else: + self.merge(alpha, beta, q) + while len(q) > 0: + gamma = q.pop(0) + for x in A_dict: + delta = table[gamma][A_dict[x]] + if delta is not None: + table[delta][A_dict_inv[x]] = None + mu = self.rep(gamma, modified=modified) + nu = self.rep(delta, modified=modified) + if table[mu][A_dict[x]] is not None: + if modified: + v = self.p_p[delta]**-1*self.P[gamma][self.A_dict[x]]**-1 + v = v*self.p_p[gamma]*self.P[mu][self.A_dict[x]] + self.modified_merge(nu, table[mu][self.A_dict[x]], v, q) + else: + self.merge(nu, table[mu][A_dict[x]], q) + elif table[nu][A_dict_inv[x]] is not None: + if modified: + v = self.p_p[gamma]**-1*self.P[gamma][self.A_dict[x]] + v = v*self.p_p[delta]*self.P[mu][self.A_dict_inv[x]] + self.modified_merge(mu, table[nu][self.A_dict_inv[x]], v, q) + else: + self.merge(mu, table[nu][A_dict_inv[x]], q) + else: + table[mu][A_dict[x]] = nu + table[nu][A_dict_inv[x]] = mu + if modified: + v = self.p_p[gamma]**-1*self.P[gamma][self.A_dict[x]]*self.p_p[delta] + self.P[mu][self.A_dict[x]] = v + self.P[nu][self.A_dict_inv[x]] = v**-1 + + # method used in the HLT strategy + def scan_and_fill(self, alpha, word): + """ + A modified version of ``scan`` routine used in the relator-based + method of coset enumeration, described on pg. 162-163 [1], which + follows the idea that whenever the procedure is called and the scan + is incomplete then it makes new definitions to enable the scan to + complete; i.e it fills in the gaps in the scan of the relator or + subgroup generator. + + """ + self.scan(alpha, word, fill=True) + + def scan_and_fill_c(self, alpha, word): + """ + A modified version of ``scan`` routine, described on Pg. 165 second + para. [1], with modification similar to that of ``scan_anf_fill`` the + only difference being it calls the coincidence procedure used in the + coset-table based method i.e. the routine ``coincidence_c`` is used. + + See Also + ======== + + scan, scan_and_fill + + """ + A_dict = self.A_dict + A_dict_inv = self.A_dict_inv + table = self.table + r = len(word) + f = alpha + i = 0 + b = alpha + j = r - 1 + # loop until it has filled the alpha row in the table. + while True: + # do the forward scanning + while i <= j and table[f][A_dict[word[i]]] is not None: + f = table[f][A_dict[word[i]]] + i += 1 + if i > j: + if f != b: + self.coincidence_c(f, b) + return + # forward scan was incomplete, scan backwards + while j >= i and table[b][A_dict_inv[word[j]]] is not None: + b = table[b][A_dict_inv[word[j]]] + j -= 1 + if j < i: + self.coincidence_c(f, b) + elif j == i: + table[f][A_dict[word[i]]] = b + table[b][A_dict_inv[word[i]]] = f + self.deduction_stack.append((f, word[i])) + else: + self.define_c(f, word[i]) + + # method used in the HLT strategy + def look_ahead(self): + """ + When combined with the HLT method this is known as HLT+Lookahead + method of coset enumeration, described on pg. 164 [1]. Whenever + ``define`` aborts due to lack of space available this procedure is + executed. This routine helps in recovering space resulting from + "coincidence" of cosets. + + """ + R = self.fp_group.relators + p = self.p + # complete scan all relators under all cosets(obviously live) + # without making new definitions + for beta in self.omega: + for w in R: + self.scan(beta, w) + if p[beta] < beta: + break + + # Pg. 166 + def process_deductions(self, R_c_x, R_c_x_inv): + """ + Processes the deductions that have been pushed onto ``deduction_stack``, + described on Pg. 166 [1] and is used in coset-table based enumeration. + + See Also + ======== + + deduction_stack + + """ + p = self.p + table = self.table + while len(self.deduction_stack) > 0: + if len(self.deduction_stack) >= CosetTable.max_stack_size: + self.look_ahead() + del self.deduction_stack[:] + continue + else: + alpha, x = self.deduction_stack.pop() + if p[alpha] == alpha: + for w in R_c_x: + self.scan_c(alpha, w) + if p[alpha] < alpha: + break + beta = table[alpha][self.A_dict[x]] + if beta is not None and p[beta] == beta: + for w in R_c_x_inv: + self.scan_c(beta, w) + if p[beta] < beta: + break + + def process_deductions_check(self, R_c_x, R_c_x_inv): + """ + A variation of ``process_deductions``, this calls ``scan_check`` + wherever ``process_deductions`` calls ``scan``, described on Pg. [1]. + + See Also + ======== + + process_deductions + + """ + table = self.table + while len(self.deduction_stack) > 0: + alpha, x = self.deduction_stack.pop() + if not all(self.scan_check(alpha, w) for w in R_c_x): + return False + beta = table[alpha][self.A_dict[x]] + if beta is not None: + if not all(self.scan_check(beta, w) for w in R_c_x_inv): + return False + return True + + def switch(self, beta, gamma): + r"""Switch the elements `\beta, \gamma \in \Omega` of ``self``, used + by the ``standardize`` procedure, described on Pg. 167 [1]. + + See Also + ======== + + standardize + + """ + A = self.A + A_dict = self.A_dict + table = self.table + for x in A: + z = table[gamma][A_dict[x]] + table[gamma][A_dict[x]] = table[beta][A_dict[x]] + table[beta][A_dict[x]] = z + for alpha in range(len(self.p)): + if self.p[alpha] == alpha: + if table[alpha][A_dict[x]] == beta: + table[alpha][A_dict[x]] = gamma + elif table[alpha][A_dict[x]] == gamma: + table[alpha][A_dict[x]] = beta + + def standardize(self): + r""" + A coset table is standardized if when running through the cosets and + within each coset through the generator images (ignoring generator + inverses), the cosets appear in order of the integers + `0, 1, \dots, n`. "Standardize" reorders the elements of `\Omega` + such that, if we scan the coset table first by elements of `\Omega` + and then by elements of A, then the cosets occur in ascending order. + ``standardize()`` is used at the end of an enumeration to permute the + cosets so that they occur in some sort of standard order. + + Notes + ===== + + procedure is described on pg. 167-168 [1], it also makes use of the + ``switch`` routine to replace by smaller integer value. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> from sympy.combinatorics.fp_groups import FpGroup, coset_enumeration_r + >>> F, x, y = free_group("x, y") + + # Example 5.3 from [1] + >>> f = FpGroup(F, [x**2*y**2, x**3*y**5]) + >>> C = coset_enumeration_r(f, []) + >>> C.compress() + >>> C.table + [[1, 3, 1, 3], [2, 0, 2, 0], [3, 1, 3, 1], [0, 2, 0, 2]] + >>> C.standardize() + >>> C.table + [[1, 2, 1, 2], [3, 0, 3, 0], [0, 3, 0, 3], [2, 1, 2, 1]] + + """ + A = self.A + A_dict = self.A_dict + gamma = 1 + for alpha, x in product(range(self.n), A): + beta = self.table[alpha][A_dict[x]] + if beta >= gamma: + if beta > gamma: + self.switch(gamma, beta) + gamma += 1 + if gamma == self.n: + return + + # Compression of a Coset Table + def compress(self): + """Removes the non-live cosets from the coset table, described on + pg. 167 [1]. + + """ + gamma = -1 + A = self.A + A_dict = self.A_dict + A_dict_inv = self.A_dict_inv + table = self.table + chi = tuple([i for i in range(len(self.p)) if self.p[i] != i]) + for alpha in self.omega: + gamma += 1 + if gamma != alpha: + # replace alpha by gamma in coset table + for x in A: + beta = table[alpha][A_dict[x]] + table[gamma][A_dict[x]] = beta + # XXX: The line below uses == rather than = which means + # that it has no effect. It is not clear though if it is + # correct simply to delete the line or to change it to + # use =. Changing it causes some tests to fail. + # + # https://github.com/sympy/sympy/issues/27633 + table[beta][A_dict_inv[x]] == gamma # noqa: B015 + # all the cosets in the table are live cosets + self.p = list(range(gamma + 1)) + # delete the useless columns + del table[len(self.p):] + # re-define values + for row in table: + for j in range(len(self.A)): + row[j] -= bisect_left(chi, row[j]) + + def conjugates(self, R): + R_c = list(chain.from_iterable((rel.cyclic_conjugates(), \ + (rel**-1).cyclic_conjugates()) for rel in R)) + R_set = set() + for conjugate in R_c: + R_set = R_set.union(conjugate) + R_c_list = [] + for x in self.A: + r = {word for word in R_set if word[0] == x} + R_c_list.append(r) + R_set.difference_update(r) + return R_c_list + + def coset_representative(self, coset): + ''' + Compute the coset representative of a given coset. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> from sympy.combinatorics.fp_groups import FpGroup, coset_enumeration_r + >>> F, x, y = free_group("x, y") + >>> f = FpGroup(F, [x**3, y**3, x**-1*y**-1*x*y]) + >>> C = coset_enumeration_r(f, [x]) + >>> C.compress() + >>> C.table + [[0, 0, 1, 2], [1, 1, 2, 0], [2, 2, 0, 1]] + >>> C.coset_representative(0) + + >>> C.coset_representative(1) + y + >>> C.coset_representative(2) + y**-1 + + ''' + for x in self.A: + gamma = self.table[coset][self.A_dict[x]] + if coset == 0: + return self.fp_group.identity + if gamma < coset: + return self.coset_representative(gamma)*x**-1 + + ############################## + # Modified Methods # + ############################## + + def modified_define(self, alpha, x): + r""" + Define a function p_p from from [1..n] to A* as + an additional component of the modified coset table. + + Parameters + ========== + + \alpha \in \Omega + x \in A* + + See Also + ======== + + define + + """ + self.define(alpha, x, modified=True) + + def modified_scan(self, alpha, w, y, fill=False): + r""" + Parameters + ========== + \alpha \in \Omega + w \in A* + y \in (YUY^-1) + fill -- `modified_scan_and_fill` when set to True. + + See Also + ======== + + scan + """ + self.scan(alpha, w, y=y, fill=fill, modified=True) + + def modified_scan_and_fill(self, alpha, w, y): + self.modified_scan(alpha, w, y, fill=True) + + def modified_merge(self, k, lamda, w, q): + r""" + Parameters + ========== + + 'k', 'lamda' -- the two class representatives to be merged. + q -- queue of length l of elements to be deleted from `\Omega` *. + w -- Word in (YUY^-1) + + See Also + ======== + + merge + """ + self.merge(k, lamda, q, w=w, modified=True) + + def modified_rep(self, k): + r""" + Parameters + ========== + + `k \in [0 \ldots n-1]` + + See Also + ======== + + rep + """ + self.rep(k, modified=True) + + def modified_coincidence(self, alpha, beta, w): + r""" + Parameters + ========== + + A coincident pair `\alpha, \beta \in \Omega, w \in Y \cup Y^{-1}` + + See Also + ======== + + coincidence + + """ + self.coincidence(alpha, beta, w=w, modified=True) + +############################################################################### +# COSET ENUMERATION # +############################################################################### + +# relator-based method +def coset_enumeration_r(fp_grp, Y, max_cosets=None, draft=None, + incomplete=False, modified=False): + """ + This is easier of the two implemented methods of coset enumeration. + and is often called the HLT method, after Hazelgrove, Leech, Trotter + The idea is that we make use of ``scan_and_fill`` makes new definitions + whenever the scan is incomplete to enable the scan to complete; this way + we fill in the gaps in the scan of the relator or subgroup generator, + that's why the name relator-based method. + + An instance of `CosetTable` for `fp_grp` can be passed as the keyword + argument `draft` in which case the coset enumeration will start with + that instance and attempt to complete it. + + When `incomplete` is `True` and the function is unable to complete for + some reason, the partially complete table will be returned. + + # TODO: complete the docstring + + See Also + ======== + + scan_and_fill, + + Examples + ======== + + >>> from sympy.combinatorics.free_groups import free_group + >>> from sympy.combinatorics.fp_groups import FpGroup, coset_enumeration_r + >>> F, x, y = free_group("x, y") + + # Example 5.1 from [1] + >>> f = FpGroup(F, [x**3, y**3, x**-1*y**-1*x*y]) + >>> C = coset_enumeration_r(f, [x]) + >>> for i in range(len(C.p)): + ... if C.p[i] == i: + ... print(C.table[i]) + [0, 0, 1, 2] + [1, 1, 2, 0] + [2, 2, 0, 1] + >>> C.p + [0, 1, 2, 1, 1] + + # Example from exercises Q2 [1] + >>> f = FpGroup(F, [x**2*y**2, y**-1*x*y*x**-3]) + >>> C = coset_enumeration_r(f, []) + >>> C.compress(); C.standardize() + >>> C.table + [[1, 2, 3, 4], + [5, 0, 6, 7], + [0, 5, 7, 6], + [7, 6, 5, 0], + [6, 7, 0, 5], + [2, 1, 4, 3], + [3, 4, 2, 1], + [4, 3, 1, 2]] + + # Example 5.2 + >>> f = FpGroup(F, [x**2, y**3, (x*y)**3]) + >>> Y = [x*y] + >>> C = coset_enumeration_r(f, Y) + >>> for i in range(len(C.p)): + ... if C.p[i] == i: + ... print(C.table[i]) + [1, 1, 2, 1] + [0, 0, 0, 2] + [3, 3, 1, 0] + [2, 2, 3, 3] + + # Example 5.3 + >>> f = FpGroup(F, [x**2*y**2, x**3*y**5]) + >>> Y = [] + >>> C = coset_enumeration_r(f, Y) + >>> for i in range(len(C.p)): + ... if C.p[i] == i: + ... print(C.table[i]) + [1, 3, 1, 3] + [2, 0, 2, 0] + [3, 1, 3, 1] + [0, 2, 0, 2] + + # Example 5.4 + >>> F, a, b, c, d, e = free_group("a, b, c, d, e") + >>> f = FpGroup(F, [a*b*c**-1, b*c*d**-1, c*d*e**-1, d*e*a**-1, e*a*b**-1]) + >>> Y = [a] + >>> C = coset_enumeration_r(f, Y) + >>> for i in range(len(C.p)): + ... if C.p[i] == i: + ... print(C.table[i]) + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + + # example of "compress" method + >>> C.compress() + >>> C.table + [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] + + # Exercises Pg. 161, Q2. + >>> F, x, y = free_group("x, y") + >>> f = FpGroup(F, [x**2*y**2, y**-1*x*y*x**-3]) + >>> Y = [] + >>> C = coset_enumeration_r(f, Y) + >>> C.compress() + >>> C.standardize() + >>> C.table + [[1, 2, 3, 4], + [5, 0, 6, 7], + [0, 5, 7, 6], + [7, 6, 5, 0], + [6, 7, 0, 5], + [2, 1, 4, 3], + [3, 4, 2, 1], + [4, 3, 1, 2]] + + # John J. Cannon; Lucien A. Dimino; George Havas; Jane M. Watson + # Mathematics of Computation, Vol. 27, No. 123. (Jul., 1973), pp. 463-490 + # from 1973chwd.pdf + # Table 1. Ex. 1 + >>> F, r, s, t = free_group("r, s, t") + >>> E1 = FpGroup(F, [t**-1*r*t*r**-2, r**-1*s*r*s**-2, s**-1*t*s*t**-2]) + >>> C = coset_enumeration_r(E1, [r]) + >>> for i in range(len(C.p)): + ... if C.p[i] == i: + ... print(C.table[i]) + [0, 0, 0, 0, 0, 0] + + Ex. 2 + >>> F, a, b = free_group("a, b") + >>> Cox = FpGroup(F, [a**6, b**6, (a*b)**2, (a**2*b**2)**2, (a**3*b**3)**5]) + >>> C = coset_enumeration_r(Cox, [a]) + >>> index = 0 + >>> for i in range(len(C.p)): + ... if C.p[i] == i: + ... index += 1 + >>> index + 500 + + # Ex. 3 + >>> F, a, b = free_group("a, b") + >>> B_2_4 = FpGroup(F, [a**4, b**4, (a*b)**4, (a**-1*b)**4, (a**2*b)**4, \ + (a*b**2)**4, (a**2*b**2)**4, (a**-1*b*a*b)**4, (a*b**-1*a*b)**4]) + >>> C = coset_enumeration_r(B_2_4, [a]) + >>> index = 0 + >>> for i in range(len(C.p)): + ... if C.p[i] == i: + ... index += 1 + >>> index + 1024 + + References + ========== + + .. [1] Holt, D., Eick, B., O'Brien, E. + "Handbook of computational group theory" + + """ + # 1. Initialize a coset table C for < X|R > + C = CosetTable(fp_grp, Y, max_cosets=max_cosets) + # Define coset table methods. + if modified: + _scan_and_fill = C.modified_scan_and_fill + _define = C.modified_define + else: + _scan_and_fill = C.scan_and_fill + _define = C.define + if draft: + C.table = draft.table[:] + C.p = draft.p[:] + R = fp_grp.relators + A_dict = C.A_dict + p = C.p + for i in range(len(Y)): + if modified: + _scan_and_fill(0, Y[i], C._grp.generators[i]) + else: + _scan_and_fill(0, Y[i]) + alpha = 0 + while alpha < C.n: + if p[alpha] == alpha: + try: + for w in R: + if modified: + _scan_and_fill(alpha, w, C._grp.identity) + else: + _scan_and_fill(alpha, w) + # if alpha was eliminated during the scan then break + if p[alpha] < alpha: + break + if p[alpha] == alpha: + for x in A_dict: + if C.table[alpha][A_dict[x]] is None: + _define(alpha, x) + except ValueError as e: + if incomplete: + return C + raise e + alpha += 1 + return C + +def modified_coset_enumeration_r(fp_grp, Y, max_cosets=None, draft=None, + incomplete=False): + r""" + Introduce a new set of symbols y \in Y that correspond to the + generators of the subgroup. Store the elements of Y as a + word P[\alpha, x] and compute the coset table similar to that of + the regular coset enumeration methods. + + Examples + ======== + + >>> from sympy.combinatorics.free_groups import free_group + >>> from sympy.combinatorics.fp_groups import FpGroup + >>> from sympy.combinatorics.coset_table import modified_coset_enumeration_r + >>> F, x, y = free_group("x, y") + >>> f = FpGroup(F, [x**3, y**3, x**-1*y**-1*x*y]) + >>> C = modified_coset_enumeration_r(f, [x]) + >>> C.table + [[0, 0, 1, 2], [1, 1, 2, 0], [2, 2, 0, 1], [None, 1, None, None], [1, 3, None, None]] + + See Also + ======== + + coset_enumertation_r + + References + ========== + + .. [1] Holt, D., Eick, B., O'Brien, E., + "Handbook of Computational Group Theory", + Section 5.3.2 + """ + return coset_enumeration_r(fp_grp, Y, max_cosets=max_cosets, draft=draft, + incomplete=incomplete, modified=True) + +# Pg. 166 +# coset-table based method +def coset_enumeration_c(fp_grp, Y, max_cosets=None, draft=None, + incomplete=False): + """ + >>> from sympy.combinatorics.free_groups import free_group + >>> from sympy.combinatorics.fp_groups import FpGroup, coset_enumeration_c + >>> F, x, y = free_group("x, y") + >>> f = FpGroup(F, [x**3, y**3, x**-1*y**-1*x*y]) + >>> C = coset_enumeration_c(f, [x]) + >>> C.table + [[0, 0, 1, 2], [1, 1, 2, 0], [2, 2, 0, 1]] + + """ + # Initialize a coset table C for < X|R > + X = fp_grp.generators + R = fp_grp.relators + C = CosetTable(fp_grp, Y, max_cosets=max_cosets) + if draft: + C.table = draft.table[:] + C.p = draft.p[:] + C.deduction_stack = draft.deduction_stack + for alpha, x in product(range(len(C.table)), X): + if C.table[alpha][C.A_dict[x]] is not None: + C.deduction_stack.append((alpha, x)) + A = C.A + # replace all the elements by cyclic reductions + R_cyc_red = [rel.identity_cyclic_reduction() for rel in R] + R_c = list(chain.from_iterable((rel.cyclic_conjugates(), (rel**-1).cyclic_conjugates()) \ + for rel in R_cyc_red)) + R_set = set() + for conjugate in R_c: + R_set = R_set.union(conjugate) + # a list of subsets of R_c whose words start with "x". + R_c_list = [] + for x in C.A: + r = {word for word in R_set if word[0] == x} + R_c_list.append(r) + R_set.difference_update(r) + for w in Y: + C.scan_and_fill_c(0, w) + for x in A: + C.process_deductions(R_c_list[C.A_dict[x]], R_c_list[C.A_dict_inv[x]]) + alpha = 0 + while alpha < len(C.table): + if C.p[alpha] == alpha: + try: + for x in C.A: + if C.p[alpha] != alpha: + break + if C.table[alpha][C.A_dict[x]] is None: + C.define_c(alpha, x) + C.process_deductions(R_c_list[C.A_dict[x]], R_c_list[C.A_dict_inv[x]]) + except ValueError as e: + if incomplete: + return C + raise e + alpha += 1 + return C diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/fp_groups.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/fp_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..95530ccd44f025eca5f029c6ba60d3727cbcb29d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/fp_groups.py @@ -0,0 +1,1352 @@ +"""Finitely Presented Groups and its algorithms. """ + +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.combinatorics.free_groups import (FreeGroup, FreeGroupElement, + free_group) +from sympy.combinatorics.rewritingsystem import RewritingSystem +from sympy.combinatorics.coset_table import (CosetTable, + coset_enumeration_r, + coset_enumeration_c) +from sympy.combinatorics import PermutationGroup +from sympy.matrices.normalforms import invariant_factors +from sympy.matrices import Matrix +from sympy.polys.polytools import gcd +from sympy.printing.defaults import DefaultPrinting +from sympy.utilities import public +from sympy.utilities.magic import pollute + +from itertools import product + + +@public +def fp_group(fr_grp, relators=()): + _fp_group = FpGroup(fr_grp, relators) + return (_fp_group,) + tuple(_fp_group._generators) + +@public +def xfp_group(fr_grp, relators=()): + _fp_group = FpGroup(fr_grp, relators) + return (_fp_group, _fp_group._generators) + +# Does not work. Both symbols and pollute are undefined. Never tested. +@public +def vfp_group(fr_grpm, relators): + _fp_group = FpGroup(symbols, relators) + pollute([sym.name for sym in _fp_group.symbols], _fp_group.generators) + return _fp_group + + +def _parse_relators(rels): + """Parse the passed relators.""" + return rels + + +############################################################################### +# FINITELY PRESENTED GROUPS # +############################################################################### + + +class FpGroup(DefaultPrinting): + """ + The FpGroup would take a FreeGroup and a list/tuple of relators, the + relators would be specified in such a way that each of them be equal to the + identity of the provided free group. + + """ + is_group = True + is_FpGroup = True + is_PermutationGroup = False + + def __init__(self, fr_grp, relators): + relators = _parse_relators(relators) + self.free_group = fr_grp + self.relators = relators + self.generators = self._generators() + self.dtype = type("FpGroupElement", (FpGroupElement,), {"group": self}) + + # CosetTable instance on identity subgroup + self._coset_table = None + # returns whether coset table on identity subgroup + # has been standardized + self._is_standardized = False + + self._order = None + self._center = None + + self._rewriting_system = RewritingSystem(self) + self._perm_isomorphism = None + return + + def _generators(self): + return self.free_group.generators + + def make_confluent(self): + ''' + Try to make the group's rewriting system confluent + + ''' + self._rewriting_system.make_confluent() + return + + def reduce(self, word): + ''' + Return the reduced form of `word` in `self` according to the group's + rewriting system. If it's confluent, the reduced form is the unique normal + form of the word in the group. + + ''' + return self._rewriting_system.reduce(word) + + def equals(self, word1, word2): + ''' + Compare `word1` and `word2` for equality in the group + using the group's rewriting system. If the system is + confluent, the returned answer is necessarily correct. + (If it is not, `False` could be returned in some cases + where in fact `word1 == word2`) + + ''' + if self.reduce(word1*word2**-1) == self.identity: + return True + elif self._rewriting_system.is_confluent: + return False + return None + + @property + def identity(self): + return self.free_group.identity + + def __contains__(self, g): + return g in self.free_group + + def subgroup(self, gens, C=None, homomorphism=False): + ''' + Return the subgroup generated by `gens` using the + Reidemeister-Schreier algorithm + homomorphism -- When set to True, return a dictionary containing the images + of the presentation generators in the original group. + + Examples + ======== + + >>> from sympy.combinatorics.fp_groups import FpGroup + >>> from sympy.combinatorics import free_group + >>> F, x, y = free_group("x, y") + >>> f = FpGroup(F, [x**3, y**5, (x*y)**2]) + >>> H = [x*y, x**-1*y**-1*x*y*x] + >>> K, T = f.subgroup(H, homomorphism=True) + >>> T(K.generators) + [x*y, x**-1*y**2*x**-1] + + ''' + + if not all(isinstance(g, FreeGroupElement) for g in gens): + raise ValueError("Generators must be `FreeGroupElement`s") + if not all(g.group == self.free_group for g in gens): + raise ValueError("Given generators are not members of the group") + if homomorphism: + g, rels, _gens = reidemeister_presentation(self, gens, C=C, homomorphism=True) + else: + g, rels = reidemeister_presentation(self, gens, C=C) + if g: + g = FpGroup(g[0].group, rels) + else: + g = FpGroup(free_group('')[0], []) + if homomorphism: + from sympy.combinatorics.homomorphisms import homomorphism + return g, homomorphism(g, self, g.generators, _gens, check=False) + return g + + def coset_enumeration(self, H, strategy="relator_based", max_cosets=None, + draft=None, incomplete=False): + """ + Return an instance of ``coset table``, when Todd-Coxeter algorithm is + run over the ``self`` with ``H`` as subgroup, using ``strategy`` + argument as strategy. The returned coset table is compressed but not + standardized. + + An instance of `CosetTable` for `fp_grp` can be passed as the keyword + argument `draft` in which case the coset enumeration will start with + that instance and attempt to complete it. + + When `incomplete` is `True` and the function is unable to complete for + some reason, the partially complete table will be returned. + + """ + if not max_cosets: + max_cosets = CosetTable.coset_table_max_limit + if strategy == 'relator_based': + C = coset_enumeration_r(self, H, max_cosets=max_cosets, + draft=draft, incomplete=incomplete) + else: + C = coset_enumeration_c(self, H, max_cosets=max_cosets, + draft=draft, incomplete=incomplete) + if C.is_complete(): + C.compress() + return C + + def standardize_coset_table(self): + """ + Standardized the coset table ``self`` and makes the internal variable + ``_is_standardized`` equal to ``True``. + + """ + self._coset_table.standardize() + self._is_standardized = True + + def coset_table(self, H, strategy="relator_based", max_cosets=None, + draft=None, incomplete=False): + """ + Return the mathematical coset table of ``self`` in ``H``. + + """ + if not H: + if self._coset_table is not None: + if not self._is_standardized: + self.standardize_coset_table() + else: + C = self.coset_enumeration([], strategy, max_cosets=max_cosets, + draft=draft, incomplete=incomplete) + self._coset_table = C + self.standardize_coset_table() + return self._coset_table.table + else: + C = self.coset_enumeration(H, strategy, max_cosets=max_cosets, + draft=draft, incomplete=incomplete) + C.standardize() + return C.table + + def order(self, strategy="relator_based"): + """ + Returns the order of the finitely presented group ``self``. It uses + the coset enumeration with identity group as subgroup, i.e ``H=[]``. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> from sympy.combinatorics.fp_groups import FpGroup + >>> F, x, y = free_group("x, y") + >>> f = FpGroup(F, [x, y**2]) + >>> f.order(strategy="coset_table_based") + 2 + + """ + if self._order is not None: + return self._order + if self._coset_table is not None: + self._order = len(self._coset_table.table) + elif len(self.relators) == 0: + self._order = self.free_group.order() + elif len(self.generators) == 1: + self._order = abs(gcd([r.array_form[0][1] for r in self.relators])) + elif self._is_infinite(): + self._order = S.Infinity + else: + gens, C = self._finite_index_subgroup() + if C: + ind = len(C.table) + self._order = ind*self.subgroup(gens, C=C).order() + else: + self._order = self.index([]) + return self._order + + def _is_infinite(self): + ''' + Test if the group is infinite. Return `True` if the test succeeds + and `None` otherwise + + ''' + used_gens = set() + for r in self.relators: + used_gens.update(r.contains_generators()) + if not set(self.generators) <= used_gens: + return True + # Abelianisation test: check is the abelianisation is infinite + abelian_rels = [] + for rel in self.relators: + abelian_rels.append([rel.exponent_sum(g) for g in self.generators]) + m = Matrix(Matrix(abelian_rels)) + if 0 in invariant_factors(m): + return True + else: + return None + + + def _finite_index_subgroup(self, s=None): + ''' + Find the elements of `self` that generate a finite index subgroup + and, if found, return the list of elements and the coset table of `self` by + the subgroup, otherwise return `(None, None)` + + ''' + gen = self.most_frequent_generator() + rels = list(self.generators) + rels.extend(self.relators) + if not s: + if len(self.generators) == 2: + s = [gen] + [g for g in self.generators if g != gen] + else: + rand = self.free_group.identity + i = 0 + while ((rand in rels or rand**-1 in rels or rand.is_identity) + and i<10): + rand = self.random() + i += 1 + s = [gen, rand] + [g for g in self.generators if g != gen] + mid = (len(s)+1)//2 + half1 = s[:mid] + half2 = s[mid:] + draft1 = None + draft2 = None + m = 200 + C = None + while not C and (m/2 < CosetTable.coset_table_max_limit): + m = min(m, CosetTable.coset_table_max_limit) + draft1 = self.coset_enumeration(half1, max_cosets=m, + draft=draft1, incomplete=True) + if draft1.is_complete(): + C = draft1 + half = half1 + else: + draft2 = self.coset_enumeration(half2, max_cosets=m, + draft=draft2, incomplete=True) + if draft2.is_complete(): + C = draft2 + half = half2 + if not C: + m *= 2 + if not C: + return None, None + C.compress() + return half, C + + def most_frequent_generator(self): + gens = self.generators + rels = self.relators + freqs = [sum(r.generator_count(g) for r in rels) for g in gens] + return gens[freqs.index(max(freqs))] + + def random(self): + import random + r = self.free_group.identity + for i in range(random.randint(2,3)): + r = r*random.choice(self.generators)**random.choice([1,-1]) + return r + + def index(self, H, strategy="relator_based"): + """ + Return the index of subgroup ``H`` in group ``self``. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> from sympy.combinatorics.fp_groups import FpGroup + >>> F, x, y = free_group("x, y") + >>> f = FpGroup(F, [x**5, y**4, y*x*y**3*x**3]) + >>> f.index([x]) + 4 + + """ + # TODO: use |G:H| = |G|/|H| (currently H can't be made into a group) + # when we know |G| and |H| + + if H == []: + return self.order() + else: + C = self.coset_enumeration(H, strategy) + return len(C.table) + + def __str__(self): + if self.free_group.rank > 30: + str_form = "" % self.free_group.rank + else: + str_form = "" % str(self.generators) + return str_form + + __repr__ = __str__ + +#============================================================================== +# PERMUTATION GROUP METHODS +#============================================================================== + + def _to_perm_group(self): + ''' + Return an isomorphic permutation group and the isomorphism. + The implementation is dependent on coset enumeration so + will only terminate for finite groups. + + ''' + from sympy.combinatorics import Permutation + from sympy.combinatorics.homomorphisms import homomorphism + if self.order() is S.Infinity: + raise NotImplementedError("Permutation presentation of infinite " + "groups is not implemented") + if self._perm_isomorphism: + T = self._perm_isomorphism + P = T.image() + else: + C = self.coset_table([]) + gens = self.generators + images = [[C[i][2*gens.index(g)] for i in range(len(C))] for g in gens] + images = [Permutation(i) for i in images] + P = PermutationGroup(images) + T = homomorphism(self, P, gens, images, check=False) + self._perm_isomorphism = T + return P, T + + def _perm_group_list(self, method_name, *args): + ''' + Given the name of a `PermutationGroup` method (returning a subgroup + or a list of subgroups) and (optionally) additional arguments it takes, + return a list or a list of lists containing the generators of this (or + these) subgroups in terms of the generators of `self`. + + ''' + P, T = self._to_perm_group() + perm_result = getattr(P, method_name)(*args) + single = False + if isinstance(perm_result, PermutationGroup): + perm_result, single = [perm_result], True + result = [] + for group in perm_result: + gens = group.generators + result.append(T.invert(gens)) + return result[0] if single else result + + def derived_series(self): + ''' + Return the list of lists containing the generators + of the subgroups in the derived series of `self`. + + ''' + return self._perm_group_list('derived_series') + + def lower_central_series(self): + ''' + Return the list of lists containing the generators + of the subgroups in the lower central series of `self`. + + ''' + return self._perm_group_list('lower_central_series') + + def center(self): + ''' + Return the list of generators of the center of `self`. + + ''' + return self._perm_group_list('center') + + + def derived_subgroup(self): + ''' + Return the list of generators of the derived subgroup of `self`. + + ''' + return self._perm_group_list('derived_subgroup') + + + def centralizer(self, other): + ''' + Return the list of generators of the centralizer of `other` + (a list of elements of `self`) in `self`. + + ''' + T = self._to_perm_group()[1] + other = T(other) + return self._perm_group_list('centralizer', other) + + def normal_closure(self, other): + ''' + Return the list of generators of the normal closure of `other` + (a list of elements of `self`) in `self`. + + ''' + T = self._to_perm_group()[1] + other = T(other) + return self._perm_group_list('normal_closure', other) + + def _perm_property(self, attr): + ''' + Given an attribute of a `PermutationGroup`, return + its value for a permutation group isomorphic to `self`. + + ''' + P = self._to_perm_group()[0] + return getattr(P, attr) + + @property + def is_abelian(self): + ''' + Check if `self` is abelian. + + ''' + return self._perm_property("is_abelian") + + @property + def is_nilpotent(self): + ''' + Check if `self` is nilpotent. + + ''' + return self._perm_property("is_nilpotent") + + @property + def is_solvable(self): + ''' + Check if `self` is solvable. + + ''' + return self._perm_property("is_solvable") + + @property + def elements(self): + ''' + List the elements of `self`. + + ''' + P, T = self._to_perm_group() + return T.invert(P.elements) + + @property + def is_cyclic(self): + """ + Return ``True`` if group is Cyclic. + + """ + if len(self.generators) <= 1: + return True + try: + P, T = self._to_perm_group() + except NotImplementedError: + raise NotImplementedError("Check for infinite Cyclic group " + "is not implemented") + return P.is_cyclic + + def abelian_invariants(self): + """ + Return Abelian Invariants of a group. + """ + try: + P, T = self._to_perm_group() + except NotImplementedError: + raise NotImplementedError("abelian invariants is not implemented" + "for infinite group") + return P.abelian_invariants() + + def composition_series(self): + """ + Return subnormal series of maximum length for a group. + """ + try: + P, T = self._to_perm_group() + except NotImplementedError: + raise NotImplementedError("composition series is not implemented" + "for infinite group") + return P.composition_series() + + +class FpSubgroup(DefaultPrinting): + ''' + The class implementing a subgroup of an FpGroup or a FreeGroup + (only finite index subgroups are supported at this point). This + is to be used if one wishes to check if an element of the original + group belongs to the subgroup + + ''' + def __init__(self, G, gens, normal=False): + super().__init__() + self.parent = G + self.generators = list({g for g in gens if g != G.identity}) + self._min_words = None #for use in __contains__ + self.C = None + self.normal = normal + + def __contains__(self, g): + + if isinstance(self.parent, FreeGroup): + if self._min_words is None: + # make _min_words - a list of subwords such that + # g is in the subgroup if and only if it can be + # partitioned into these subwords. Infinite families of + # subwords are presented by tuples, e.g. (r, w) + # stands for the family of subwords r*w**n*r**-1 + + def _process(w): + # this is to be used before adding new words + # into _min_words; if the word w is not cyclically + # reduced, it will generate an infinite family of + # subwords so should be written as a tuple; + # if it is, w**-1 should be added to the list + # as well + p, r = w.cyclic_reduction(removed=True) + if not r.is_identity: + return [(r, p)] + else: + return [w, w**-1] + + # make the initial list + gens = [] + for w in self.generators: + if self.normal: + w = w.cyclic_reduction() + gens.extend(_process(w)) + + for w1 in gens: + for w2 in gens: + # if w1 and w2 are equal or are inverses, continue + if w1 == w2 or (not isinstance(w1, tuple) + and w1**-1 == w2): + continue + + # if the start of one word is the inverse of the + # end of the other, their multiple should be added + # to _min_words because of cancellation + if isinstance(w1, tuple): + # start, end + s1, s2 = w1[0][0], w1[0][0]**-1 + else: + s1, s2 = w1[0], w1[len(w1)-1] + + if isinstance(w2, tuple): + # start, end + r1, r2 = w2[0][0], w2[0][0]**-1 + else: + r1, r2 = w2[0], w2[len(w1)-1] + + # p1 and p2 are w1 and w2 or, in case when + # w1 or w2 is an infinite family, a representative + p1, p2 = w1, w2 + if isinstance(w1, tuple): + p1 = w1[0]*w1[1]*w1[0]**-1 + if isinstance(w2, tuple): + p2 = w2[0]*w2[1]*w2[0]**-1 + + # add the product of the words to the list is necessary + if r1**-1 == s2 and not (p1*p2).is_identity: + new = _process(p1*p2) + if new not in gens: + gens.extend(new) + + if r2**-1 == s1 and not (p2*p1).is_identity: + new = _process(p2*p1) + if new not in gens: + gens.extend(new) + + self._min_words = gens + + min_words = self._min_words + + def _is_subword(w): + # check if w is a word in _min_words or one of + # the infinite families in it + w, r = w.cyclic_reduction(removed=True) + if r.is_identity or self.normal: + return w in min_words + else: + t = [s[1] for s in min_words if isinstance(s, tuple) + and s[0] == r] + return [s for s in t if w.power_of(s)] != [] + + # store the solution of words for which the result of + # _word_break (below) is known + known = {} + + def _word_break(w): + # check if w can be written as a product of words + # in min_words + if len(w) == 0: + return True + i = 0 + while i < len(w): + i += 1 + prefix = w.subword(0, i) + if not _is_subword(prefix): + continue + rest = w.subword(i, len(w)) + if rest not in known: + known[rest] = _word_break(rest) + if known[rest]: + return True + return False + + if self.normal: + g = g.cyclic_reduction() + return _word_break(g) + else: + if self.C is None: + C = self.parent.coset_enumeration(self.generators) + self.C = C + i = 0 + C = self.C + for j in range(len(g)): + i = C.table[i][C.A_dict[g[j]]] + return i == 0 + + def order(self): + if not self.generators: + return S.One + if isinstance(self.parent, FreeGroup): + return S.Infinity + if self.C is None: + C = self.parent.coset_enumeration(self.generators) + self.C = C + # This is valid because `len(self.C.table)` (the index of the subgroup) + # will always be finite - otherwise coset enumeration doesn't terminate + return self.parent.order()/len(self.C.table) + + def to_FpGroup(self): + if isinstance(self.parent, FreeGroup): + gen_syms = [('x_%d'%i) for i in range(len(self.generators))] + return free_group(', '.join(gen_syms))[0] + return self.parent.subgroup(C=self.C) + + def __str__(self): + if len(self.generators) > 30: + str_form = "" % len(self.generators) + else: + str_form = "" % str(self.generators) + return str_form + + __repr__ = __str__ + + +############################################################################### +# LOW INDEX SUBGROUPS # +############################################################################### + +def low_index_subgroups(G, N, Y=()): + """ + Implements the Low Index Subgroups algorithm, i.e find all subgroups of + ``G`` upto a given index ``N``. This implements the method described in + [Sim94]. This procedure involves a backtrack search over incomplete Coset + Tables, rather than over forced coincidences. + + Parameters + ========== + + G: An FpGroup < X|R > + N: positive integer, representing the maximum index value for subgroups + Y: (an optional argument) specifying a list of subgroup generators, such + that each of the resulting subgroup contains the subgroup generated by Y. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> from sympy.combinatorics.fp_groups import FpGroup, low_index_subgroups + >>> F, x, y = free_group("x, y") + >>> f = FpGroup(F, [x**2, y**3, (x*y)**4]) + >>> L = low_index_subgroups(f, 4) + >>> for coset_table in L: + ... print(coset_table.table) + [[0, 0, 0, 0]] + [[0, 0, 1, 2], [1, 1, 2, 0], [3, 3, 0, 1], [2, 2, 3, 3]] + [[0, 0, 1, 2], [2, 2, 2, 0], [1, 1, 0, 1]] + [[1, 1, 0, 0], [0, 0, 1, 1]] + + References + ========== + + .. [1] Holt, D., Eick, B., O'Brien, E. + "Handbook of Computational Group Theory" + Section 5.4 + + .. [2] Marston Conder and Peter Dobcsanyi + "Applications and Adaptions of the Low Index Subgroups Procedure" + + """ + C = CosetTable(G, []) + R = G.relators + # length chosen for the length of the short relators + len_short_rel = 5 + # elements of R2 only checked at the last step for complete + # coset tables + R2 = {rel for rel in R if len(rel) > len_short_rel} + # elements of R1 are used in inner parts of the process to prune + # branches of the search tree, + R1 = {rel.identity_cyclic_reduction() for rel in set(R) - R2} + R1_c_list = C.conjugates(R1) + S = [] + descendant_subgroups(S, C, R1_c_list, C.A[0], R2, N, Y) + return S + + +def descendant_subgroups(S, C, R1_c_list, x, R2, N, Y): + A_dict = C.A_dict + A_dict_inv = C.A_dict_inv + if C.is_complete(): + # if C is complete then it only needs to test + # whether the relators in R2 are satisfied + for w, alpha in product(R2, C.omega): + if not C.scan_check(alpha, w): + return + # relators in R2 are satisfied, append the table to list + S.append(C) + else: + # find the first undefined entry in Coset Table + for alpha, x in product(range(len(C.table)), C.A): + if C.table[alpha][A_dict[x]] is None: + # this is "x" in pseudo-code (using "y" makes it clear) + undefined_coset, undefined_gen = alpha, x + break + # for filling up the undefine entry we try all possible values + # of beta in Omega or beta = n where beta^(undefined_gen^-1) is undefined + reach = C.omega + [C.n] + for beta in reach: + if beta < N: + if beta == C.n or C.table[beta][A_dict_inv[undefined_gen]] is None: + try_descendant(S, C, R1_c_list, R2, N, undefined_coset, \ + undefined_gen, beta, Y) + + +def try_descendant(S, C, R1_c_list, R2, N, alpha, x, beta, Y): + r""" + Solves the problem of trying out each individual possibility + for `\alpha^x. + + """ + D = C.copy() + if beta == D.n and beta < N: + D.table.append([None]*len(D.A)) + D.p.append(beta) + D.table[alpha][D.A_dict[x]] = beta + D.table[beta][D.A_dict_inv[x]] = alpha + D.deduction_stack.append((alpha, x)) + if not D.process_deductions_check(R1_c_list[D.A_dict[x]], \ + R1_c_list[D.A_dict_inv[x]]): + return + for w in Y: + if not D.scan_check(0, w): + return + if first_in_class(D, Y): + descendant_subgroups(S, D, R1_c_list, x, R2, N, Y) + + +def first_in_class(C, Y=()): + """ + Checks whether the subgroup ``H=G1`` corresponding to the Coset Table + could possibly be the canonical representative of its conjugacy class. + + Parameters + ========== + + C: CosetTable + + Returns + ======= + + bool: True/False + + If this returns False, then no descendant of C can have that property, and + so we can abandon C. If it returns True, then we need to process further + the node of the search tree corresponding to C, and so we call + ``descendant_subgroups`` recursively on C. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> from sympy.combinatorics.fp_groups import FpGroup, CosetTable, first_in_class + >>> F, x, y = free_group("x, y") + >>> f = FpGroup(F, [x**2, y**3, (x*y)**4]) + >>> C = CosetTable(f, []) + >>> C.table = [[0, 0, None, None]] + >>> first_in_class(C) + True + >>> C.table = [[1, 1, 1, None], [0, 0, None, 1]]; C.p = [0, 1] + >>> first_in_class(C) + True + >>> C.table = [[1, 1, 2, 1], [0, 0, 0, None], [None, None, None, 0]] + >>> C.p = [0, 1, 2] + >>> first_in_class(C) + False + >>> C.table = [[1, 1, 1, 2], [0, 0, 2, 0], [2, None, 0, 1]] + >>> first_in_class(C) + False + + # TODO:: Sims points out in [Sim94] that performance can be improved by + # remembering some of the information computed by ``first_in_class``. If + # the ``continue alpha`` statement is executed at line 14, then the same thing + # will happen for that value of alpha in any descendant of the table C, and so + # the values the values of alpha for which this occurs could profitably be + # stored and passed through to the descendants of C. Of course this would + # make the code more complicated. + + # The code below is taken directly from the function on page 208 of [Sim94] + # nu[alpha] + + """ + n = C.n + # lamda is the largest numbered point in Omega_c_alpha which is currently defined + lamda = -1 + # for alpha in Omega_c, nu[alpha] is the point in Omega_c_alpha corresponding to alpha + nu = [None]*n + # for alpha in Omega_c_alpha, mu[alpha] is the point in Omega_c corresponding to alpha + mu = [None]*n + # mutually nu and mu are the mutually-inverse equivalence maps between + # Omega_c_alpha and Omega_c + next_alpha = False + # For each 0!=alpha in [0 .. nc-1], we start by constructing the equivalent + # standardized coset table C_alpha corresponding to H_alpha + for alpha in range(1, n): + # reset nu to "None" after previous value of alpha + for beta in range(lamda+1): + nu[mu[beta]] = None + # we only want to reject our current table in favour of a preceding + # table in the ordering in which 1 is replaced by alpha, if the subgroup + # G_alpha corresponding to this preceding table definitely contains the + # given subgroup + for w in Y: + # TODO: this should support input of a list of general words + # not just the words which are in "A" (i.e gen and gen^-1) + if C.table[alpha][C.A_dict[w]] != alpha: + # continue with alpha + next_alpha = True + break + if next_alpha: + next_alpha = False + continue + # try alpha as the new point 0 in Omega_C_alpha + mu[0] = alpha + nu[alpha] = 0 + # compare corresponding entries in C and C_alpha + lamda = 0 + for beta in range(n): + for x in C.A: + gamma = C.table[beta][C.A_dict[x]] + delta = C.table[mu[beta]][C.A_dict[x]] + # if either of the entries is undefined, + # we move with next alpha + if gamma is None or delta is None: + # continue with alpha + next_alpha = True + break + if nu[delta] is None: + # delta becomes the next point in Omega_C_alpha + lamda += 1 + nu[delta] = lamda + mu[lamda] = delta + if nu[delta] < gamma: + return False + if nu[delta] > gamma: + # continue with alpha + next_alpha = True + break + if next_alpha: + next_alpha = False + break + return True + +#======================================================================== +# Simplifying Presentation +#======================================================================== + +def simplify_presentation(*args, change_gens=False): + ''' + For an instance of `FpGroup`, return a simplified isomorphic copy of + the group (e.g. remove redundant generators or relators). Alternatively, + a list of generators and relators can be passed in which case the + simplified lists will be returned. + + By default, the generators of the group are unchanged. If you would + like to remove redundant generators, set the keyword argument + `change_gens = True`. + + ''' + if len(args) == 1: + if not isinstance(args[0], FpGroup): + raise TypeError("The argument must be an instance of FpGroup") + G = args[0] + gens, rels = simplify_presentation(G.generators, G.relators, + change_gens=change_gens) + if gens: + return FpGroup(gens[0].group, rels) + return FpGroup(FreeGroup([]), []) + elif len(args) == 2: + gens, rels = args[0][:], args[1][:] + if not gens: + return gens, rels + identity = gens[0].group.identity + else: + if len(args) == 0: + m = "Not enough arguments" + else: + m = "Too many arguments" + raise RuntimeError(m) + + prev_gens = [] + prev_rels = [] + while not set(prev_rels) == set(rels): + prev_rels = rels + while change_gens and not set(prev_gens) == set(gens): + prev_gens = gens + gens, rels = elimination_technique_1(gens, rels, identity) + rels = _simplify_relators(rels) + + if change_gens: + syms = [g.array_form[0][0] for g in gens] + F = free_group(syms)[0] + identity = F.identity + gens = F.generators + subs = dict(zip(syms, gens)) + for j, r in enumerate(rels): + a = r.array_form + rel = identity + for sym, p in a: + rel = rel*subs[sym]**p + rels[j] = rel + return gens, rels + +def _simplify_relators(rels): + """ + Simplifies a set of relators. All relators are checked to see if they are + of the form `gen^n`. If any such relators are found then all other relators + are processed for strings in the `gen` known order. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> from sympy.combinatorics.fp_groups import _simplify_relators + >>> F, x, y = free_group("x, y") + >>> w1 = [x**2*y**4, x**3] + >>> _simplify_relators(w1) + [x**3, x**-1*y**4] + + >>> w2 = [x**2*y**-4*x**5, x**3, x**2*y**8, y**5] + >>> _simplify_relators(w2) + [x**-1*y**-2, x**-1*y*x**-1, x**3, y**5] + + >>> w3 = [x**6*y**4, x**4] + >>> _simplify_relators(w3) + [x**4, x**2*y**4] + + >>> w4 = [x**2, x**5, y**3] + >>> _simplify_relators(w4) + [x, y**3] + + """ + rels = rels[:] + + if not rels: + return [] + + identity = rels[0].group.identity + + # build dictionary with "gen: n" where gen^n is one of the relators + exps = {} + for i in range(len(rels)): + rel = rels[i] + if rel.number_syllables() == 1: + g = rel[0] + exp = abs(rel.array_form[0][1]) + if rel.array_form[0][1] < 0: + rels[i] = rels[i]**-1 + g = g**-1 + if g in exps: + exp = gcd(exp, exps[g].array_form[0][1]) + exps[g] = g**exp + + one_syllables_words = list(exps.values()) + # decrease some of the exponents in relators, making use of the single + # syllable relators + for i, rel in enumerate(rels): + if rel in one_syllables_words: + continue + rel = rel.eliminate_words(one_syllables_words, _all = True) + # if rels[i] contains g**n where abs(n) is greater than half of the power p + # of g in exps, g**n can be replaced by g**(n-p) (or g**(p-n) if n<0) + for g in rel.contains_generators(): + if g in exps: + exp = exps[g].array_form[0][1] + max_exp = (exp + 1)//2 + rel = rel.eliminate_word(g**(max_exp), g**(max_exp-exp), _all = True) + rel = rel.eliminate_word(g**(-max_exp), g**(-(max_exp-exp)), _all = True) + rels[i] = rel + + rels = [r.identity_cyclic_reduction() for r in rels] + + rels += one_syllables_words # include one_syllable_words in the list of relators + rels = list(set(rels)) # get unique values in rels + rels.sort() + + # remove entries in rels + try: + rels.remove(identity) + except ValueError: + pass + return rels + +# Pg 350, section 2.5.1 from [2] +def elimination_technique_1(gens, rels, identity): + rels = rels[:] + # the shorter relators are examined first so that generators selected for + # elimination will have shorter strings as equivalent + rels.sort() + gens = gens[:] + redundant_gens = {} + redundant_rels = [] + used_gens = set() + # examine each relator in relator list for any generator occurring exactly + # once + for rel in rels: + # don't look for a redundant generator in a relator which + # depends on previously found ones + contained_gens = rel.contains_generators() + if any(g in contained_gens for g in redundant_gens): + continue + contained_gens = list(contained_gens) + contained_gens.sort(reverse = True) + for gen in contained_gens: + if rel.generator_count(gen) == 1 and gen not in used_gens: + k = rel.exponent_sum(gen) + gen_index = rel.index(gen**k) + bk = rel.subword(gen_index + 1, len(rel)) + fw = rel.subword(0, gen_index) + chi = bk*fw + redundant_gens[gen] = chi**(-1*k) + used_gens.update(chi.contains_generators()) + redundant_rels.append(rel) + break + rels = [r for r in rels if r not in redundant_rels] + # eliminate the redundant generators from remaining relators + rels = [r.eliminate_words(redundant_gens, _all = True).identity_cyclic_reduction() for r in rels] + rels = list(set(rels)) + try: + rels.remove(identity) + except ValueError: + pass + gens = [g for g in gens if g not in redundant_gens] + return gens, rels + +############################################################################### +# SUBGROUP PRESENTATIONS # +############################################################################### + +# Pg 175 [1] +def define_schreier_generators(C, homomorphism=False): + ''' + Parameters + ========== + + C -- Coset table. + homomorphism -- When set to True, return a dictionary containing the images + of the presentation generators in the original group. + ''' + y = [] + gamma = 1 + f = C.fp_group + X = f.generators + if homomorphism: + # `_gens` stores the elements of the parent group to + # to which the schreier generators correspond to. + _gens = {} + # compute the schreier Traversal + tau = {} + tau[0] = f.identity + C.P = [[None]*len(C.A) for i in range(C.n)] + for alpha, x in product(C.omega, C.A): + beta = C.table[alpha][C.A_dict[x]] + if beta == gamma: + C.P[alpha][C.A_dict[x]] = "" + C.P[beta][C.A_dict_inv[x]] = "" + gamma += 1 + if homomorphism: + tau[beta] = tau[alpha]*x + elif x in X and C.P[alpha][C.A_dict[x]] is None: + y_alpha_x = '%s_%s' % (x, alpha) + y.append(y_alpha_x) + C.P[alpha][C.A_dict[x]] = y_alpha_x + if homomorphism: + _gens[y_alpha_x] = tau[alpha]*x*tau[beta]**-1 + grp_gens = list(free_group(', '.join(y))) + C._schreier_free_group = grp_gens.pop(0) + C._schreier_generators = grp_gens + if homomorphism: + C._schreier_gen_elem = _gens + # replace all elements of P by, free group elements + for i, j in product(range(len(C.P)), range(len(C.A))): + # if equals "", replace by identity element + if C.P[i][j] == "": + C.P[i][j] = C._schreier_free_group.identity + elif isinstance(C.P[i][j], str): + r = C._schreier_generators[y.index(C.P[i][j])] + C.P[i][j] = r + beta = C.table[i][j] + C.P[beta][j + 1] = r**-1 + +def reidemeister_relators(C): + R = C.fp_group.relators + rels = [rewrite(C, coset, word) for word in R for coset in range(C.n)] + order_1_gens = {i for i in rels if len(i) == 1} + + # remove all the order 1 generators from relators + rels = list(filter(lambda rel: rel not in order_1_gens, rels)) + + # replace order 1 generators by identity element in reidemeister relators + for i in range(len(rels)): + w = rels[i] + w = w.eliminate_words(order_1_gens, _all=True) + rels[i] = w + + C._schreier_generators = [i for i in C._schreier_generators + if not (i in order_1_gens or i**-1 in order_1_gens)] + + # Tietze transformation 1 i.e TT_1 + # remove cyclic conjugate elements from relators + i = 0 + while i < len(rels): + w = rels[i] + j = i + 1 + while j < len(rels): + if w.is_cyclic_conjugate(rels[j]): + del rels[j] + else: + j += 1 + i += 1 + + C._reidemeister_relators = rels + + +def rewrite(C, alpha, w): + """ + Parameters + ========== + + C: CosetTable + alpha: A live coset + w: A word in `A*` + + Returns + ======= + + rho(tau(alpha), w) + + Examples + ======== + + >>> from sympy.combinatorics.fp_groups import FpGroup, CosetTable, define_schreier_generators, rewrite + >>> from sympy.combinatorics import free_group + >>> F, x, y = free_group("x, y") + >>> f = FpGroup(F, [x**2, y**3, (x*y)**6]) + >>> C = CosetTable(f, []) + >>> C.table = [[1, 1, 2, 3], [0, 0, 4, 5], [4, 4, 3, 0], [5, 5, 0, 2], [2, 2, 5, 1], [3, 3, 1, 4]] + >>> C.p = [0, 1, 2, 3, 4, 5] + >>> define_schreier_generators(C) + >>> rewrite(C, 0, (x*y)**6) + x_4*y_2*x_3*x_1*x_2*y_4*x_5 + + """ + v = C._schreier_free_group.identity + for i in range(len(w)): + x_i = w[i] + v = v*C.P[alpha][C.A_dict[x_i]] + alpha = C.table[alpha][C.A_dict[x_i]] + return v + +# Pg 350, section 2.5.2 from [2] +def elimination_technique_2(C): + """ + This technique eliminates one generator at a time. Heuristically this + seems superior in that we may select for elimination the generator with + shortest equivalent string at each stage. + + >>> from sympy.combinatorics import free_group + >>> from sympy.combinatorics.fp_groups import FpGroup, coset_enumeration_r, \ + reidemeister_relators, define_schreier_generators, elimination_technique_2 + >>> F, x, y = free_group("x, y") + >>> f = FpGroup(F, [x**3, y**5, (x*y)**2]); H = [x*y, x**-1*y**-1*x*y*x] + >>> C = coset_enumeration_r(f, H) + >>> C.compress(); C.standardize() + >>> define_schreier_generators(C) + >>> reidemeister_relators(C) + >>> elimination_technique_2(C) + ([y_1, y_2], [y_2**-3, y_2*y_1*y_2*y_1*y_2*y_1, y_1**2]) + + """ + rels = C._reidemeister_relators + rels.sort(reverse=True) + gens = C._schreier_generators + for i in range(len(gens) - 1, -1, -1): + rel = rels[i] + for j in range(len(gens) - 1, -1, -1): + gen = gens[j] + if rel.generator_count(gen) == 1: + k = rel.exponent_sum(gen) + gen_index = rel.index(gen**k) + bk = rel.subword(gen_index + 1, len(rel)) + fw = rel.subword(0, gen_index) + rep_by = (bk*fw)**(-1*k) + del rels[i] + del gens[j] + rels = [rel.eliminate_word(gen, rep_by) for rel in rels] + break + C._reidemeister_relators = rels + C._schreier_generators = gens + return C._schreier_generators, C._reidemeister_relators + +def reidemeister_presentation(fp_grp, H, C=None, homomorphism=False): + """ + Parameters + ========== + + fp_group: A finitely presented group, an instance of FpGroup + H: A subgroup whose presentation is to be found, given as a list + of words in generators of `fp_grp` + homomorphism: When set to True, return a homomorphism from the subgroup + to the parent group + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> from sympy.combinatorics.fp_groups import FpGroup, reidemeister_presentation + >>> F, x, y = free_group("x, y") + + Example 5.6 Pg. 177 from [1] + >>> f = FpGroup(F, [x**3, y**5, (x*y)**2]) + >>> H = [x*y, x**-1*y**-1*x*y*x] + >>> reidemeister_presentation(f, H) + ((y_1, y_2), (y_1**2, y_2**3, y_2*y_1*y_2*y_1*y_2*y_1)) + + Example 5.8 Pg. 183 from [1] + >>> f = FpGroup(F, [x**3, y**3, (x*y)**3]) + >>> H = [x*y, x*y**-1] + >>> reidemeister_presentation(f, H) + ((x_0, y_0), (x_0**3, y_0**3, x_0*y_0*x_0*y_0*x_0*y_0)) + + Exercises Q2. Pg 187 from [1] + >>> f = FpGroup(F, [x**2*y**2, y**-1*x*y*x**-3]) + >>> H = [x] + >>> reidemeister_presentation(f, H) + ((x_0,), (x_0**4,)) + + Example 5.9 Pg. 183 from [1] + >>> f = FpGroup(F, [x**3*y**-3, (x*y)**3, (x*y**-1)**2]) + >>> H = [x] + >>> reidemeister_presentation(f, H) + ((x_0,), (x_0**6,)) + + """ + if not C: + C = coset_enumeration_r(fp_grp, H) + C.compress(); C.standardize() + define_schreier_generators(C, homomorphism=homomorphism) + reidemeister_relators(C) + gens, rels = C._schreier_generators, C._reidemeister_relators + gens, rels = simplify_presentation(gens, rels, change_gens=True) + + C.schreier_generators = tuple(gens) + C.reidemeister_relators = tuple(rels) + + if homomorphism: + _gens = [C._schreier_gen_elem[str(gen)] for gen in gens] + return C.schreier_generators, C.reidemeister_relators, _gens + + return C.schreier_generators, C.reidemeister_relators + + +FpGroupElement = FreeGroupElement diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/free_groups.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/free_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..2ec85f4fac73fba95d4644a37ecb27140e45a4c5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/free_groups.py @@ -0,0 +1,1360 @@ +from __future__ import annotations + +from sympy.core import S +from sympy.core.expr import Expr +from sympy.core.symbol import Symbol, symbols as _symbols +from sympy.core.sympify import CantSympify +from sympy.printing.defaults import DefaultPrinting +from sympy.utilities import public +from sympy.utilities.iterables import flatten, is_sequence +from sympy.utilities.magic import pollute +from sympy.utilities.misc import as_int + + +@public +def free_group(symbols): + """Construct a free group returning ``(FreeGroup, (f_0, f_1, ..., f_(n-1))``. + + Parameters + ========== + + symbols : str, Symbol/Expr or sequence of str, Symbol/Expr (may be empty) + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> F, x, y, z = free_group("x, y, z") + >>> F + + >>> x**2*y**-1 + x**2*y**-1 + >>> type(_) + + + """ + _free_group = FreeGroup(symbols) + return (_free_group,) + tuple(_free_group.generators) + +@public +def xfree_group(symbols): + """Construct a free group returning ``(FreeGroup, (f_0, f_1, ..., f_(n-1)))``. + + Parameters + ========== + + symbols : str, Symbol/Expr or sequence of str, Symbol/Expr (may be empty) + + Examples + ======== + + >>> from sympy.combinatorics.free_groups import xfree_group + >>> F, (x, y, z) = xfree_group("x, y, z") + >>> F + + >>> y**2*x**-2*z**-1 + y**2*x**-2*z**-1 + >>> type(_) + + + """ + _free_group = FreeGroup(symbols) + return (_free_group, _free_group.generators) + +@public +def vfree_group(symbols): + """Construct a free group and inject ``f_0, f_1, ..., f_(n-1)`` as symbols + into the global namespace. + + Parameters + ========== + + symbols : str, Symbol/Expr or sequence of str, Symbol/Expr (may be empty) + + Examples + ======== + + >>> from sympy.combinatorics.free_groups import vfree_group + >>> vfree_group("x, y, z") + + >>> x**2*y**-2*z # noqa: F821 + x**2*y**-2*z + >>> type(_) + + + """ + _free_group = FreeGroup(symbols) + pollute([sym.name for sym in _free_group.symbols], _free_group.generators) + return _free_group + + +def _parse_symbols(symbols): + if not symbols: + return () + if isinstance(symbols, str): + return _symbols(symbols, seq=True) + elif isinstance(symbols, (Expr, FreeGroupElement)): + return (symbols,) + elif is_sequence(symbols): + if all(isinstance(s, str) for s in symbols): + return _symbols(symbols) + elif all(isinstance(s, Expr) for s in symbols): + return symbols + raise ValueError("The type of `symbols` must be one of the following: " + "a str, Symbol/Expr or a sequence of " + "one of these types") + + +############################################################################## +# FREE GROUP # +############################################################################## + +_free_group_cache: dict[int, FreeGroup] = {} + +class FreeGroup(DefaultPrinting): + """ + Free group with finite or infinite number of generators. Its input API + is that of a str, Symbol/Expr or a sequence of one of + these types (which may be empty) + + See Also + ======== + + sympy.polys.rings.PolyRing + + References + ========== + + .. [1] https://www.gap-system.org/Manuals/doc/ref/chap37.html + + .. [2] https://en.wikipedia.org/wiki/Free_group + + """ + is_associative = True + is_group = True + is_FreeGroup = True + is_PermutationGroup = False + relators: list[Expr] = [] + + def __new__(cls, symbols): + symbols = tuple(_parse_symbols(symbols)) + rank = len(symbols) + _hash = hash((cls.__name__, symbols, rank)) + obj = _free_group_cache.get(_hash) + + if obj is None: + obj = object.__new__(cls) + obj._hash = _hash + obj._rank = rank + # dtype method is used to create new instances of FreeGroupElement + obj.dtype = type("FreeGroupElement", (FreeGroupElement,), {"group": obj}) + obj.symbols = symbols + obj.generators = obj._generators() + obj._gens_set = set(obj.generators) + for symbol, generator in zip(obj.symbols, obj.generators): + if isinstance(symbol, Symbol): + name = symbol.name + if hasattr(obj, name): + setattr(obj, name, generator) + + _free_group_cache[_hash] = obj + + return obj + + def __getnewargs__(self): + """Return a tuple of arguments that must be passed to __new__ in order to support pickling this object.""" + return (self.symbols,) + + def __getstate__(self): + # Don't pickle any fields because they are regenerated within __new__ + return None + + def _generators(group): + """Returns the generators of the FreeGroup. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> F, x, y, z = free_group("x, y, z") + >>> F.generators + (x, y, z) + + """ + gens = [] + for sym in group.symbols: + elm = ((sym, 1),) + gens.append(group.dtype(elm)) + return tuple(gens) + + def clone(self, symbols=None): + return self.__class__(symbols or self.symbols) + + def __contains__(self, i): + """Return True if ``i`` is contained in FreeGroup.""" + if not isinstance(i, FreeGroupElement): + return False + group = i.group + return self == group + + def __hash__(self): + return self._hash + + def __len__(self): + return self.rank + + def __str__(self): + if self.rank > 30: + str_form = "" % self.rank + else: + str_form = "" + return str_form + + __repr__ = __str__ + + def __getitem__(self, index): + symbols = self.symbols[index] + return self.clone(symbols=symbols) + + def __eq__(self, other): + """No ``FreeGroup`` is equal to any "other" ``FreeGroup``. + """ + return self is other + + def index(self, gen): + """Return the index of the generator `gen` from ``(f_0, ..., f_(n-1))``. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> F, x, y = free_group("x, y") + >>> F.index(y) + 1 + >>> F.index(x) + 0 + + """ + if isinstance(gen, self.dtype): + return self.generators.index(gen) + else: + raise ValueError("expected a generator of Free Group %s, got %s" % (self, gen)) + + def order(self): + """Return the order of the free group. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> F, x, y = free_group("x, y") + >>> F.order() + oo + + >>> free_group("")[0].order() + 1 + + """ + if self.rank == 0: + return S.One + else: + return S.Infinity + + @property + def elements(self): + """ + Return the elements of the free group. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> (z,) = free_group("") + >>> z.elements + {} + + """ + if self.rank == 0: + # A set containing Identity element of `FreeGroup` self is returned + return {self.identity} + else: + raise ValueError("Group contains infinitely many elements" + ", hence cannot be represented") + + @property + def rank(self): + r""" + In group theory, the `rank` of a group `G`, denoted `G.rank`, + can refer to the smallest cardinality of a generating set + for G, that is + + \operatorname{rank}(G)=\min\{ |X|: X\subseteq G, \left\langle X\right\rangle =G\}. + + """ + return self._rank + + @property + def is_abelian(self): + """Returns if the group is Abelian. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, x, y, z = free_group("x y z") + >>> f.is_abelian + False + + """ + return self.rank in (0, 1) + + @property + def identity(self): + """Returns the identity element of free group.""" + return self.dtype() + + def contains(self, g): + """Tests if Free Group element ``g`` belong to self, ``G``. + + In mathematical terms any linear combination of generators + of a Free Group is contained in it. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, x, y, z = free_group("x y z") + >>> f.contains(x**3*y**2) + True + + """ + if not isinstance(g, FreeGroupElement): + return False + elif self != g.group: + return False + else: + return True + + def center(self): + """Returns the center of the free group `self`.""" + return {self.identity} + + +############################################################################ +# FreeGroupElement # +############################################################################ + + +class FreeGroupElement(CantSympify, DefaultPrinting, tuple): + """Used to create elements of FreeGroup. It cannot be used directly to + create a free group element. It is called by the `dtype` method of the + `FreeGroup` class. + + """ + __slots__ = () + is_assoc_word = True + + def new(self, init): + return self.__class__(init) + + _hash = None + + def __hash__(self): + _hash = self._hash + if _hash is None: + self._hash = _hash = hash((self.group, frozenset(tuple(self)))) + return _hash + + def copy(self): + return self.new(self) + + @property + def is_identity(self): + return not self.array_form + + @property + def array_form(self): + """ + SymPy provides two different internal kinds of representation + of associative words. The first one is called the `array_form` + which is a tuple containing `tuples` as its elements, where the + size of each tuple is two. At the first position the tuple + contains the `symbol-generator`, while at the second position + of tuple contains the exponent of that generator at the position. + Since elements (i.e. words) do not commute, the indexing of tuple + makes that property to stay. + + The structure in ``array_form`` of ``FreeGroupElement`` is of form: + + ``( ( symbol_of_gen, exponent ), ( , ), ... ( , ) )`` + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, x, y, z = free_group("x y z") + >>> (x*z).array_form + ((x, 1), (z, 1)) + >>> (x**2*z*y*x**2).array_form + ((x, 2), (z, 1), (y, 1), (x, 2)) + + See Also + ======== + + letter_repr + + """ + return tuple(self) + + @property + def letter_form(self): + """ + The letter representation of a ``FreeGroupElement`` is a tuple + of generator symbols, with each entry corresponding to a group + generator. Inverses of the generators are represented by + negative generator symbols. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, a, b, c, d = free_group("a b c d") + >>> (a**3).letter_form + (a, a, a) + >>> (a**2*d**-2*a*b**-4).letter_form + (a, a, -d, -d, a, -b, -b, -b, -b) + >>> (a**-2*b**3*d).letter_form + (-a, -a, b, b, b, d) + + See Also + ======== + + array_form + + """ + return tuple(flatten([(i,)*j if j > 0 else (-i,)*(-j) + for i, j in self.array_form])) + + def __getitem__(self, i): + group = self.group + r = self.letter_form[i] + if r.is_Symbol: + return group.dtype(((r, 1),)) + else: + return group.dtype(((-r, -1),)) + + def index(self, gen): + if len(gen) != 1: + raise ValueError() + return (self.letter_form).index(gen.letter_form[0]) + + @property + def letter_form_elm(self): + """ + """ + group = self.group + r = self.letter_form + return [group.dtype(((elm,1),)) if elm.is_Symbol \ + else group.dtype(((-elm,-1),)) for elm in r] + + @property + def ext_rep(self): + """This is called the External Representation of ``FreeGroupElement`` + """ + return tuple(flatten(self.array_form)) + + def __contains__(self, gen): + return gen.array_form[0][0] in tuple([r[0] for r in self.array_form]) + + def __str__(self): + if self.is_identity: + return "" + + str_form = "" + array_form = self.array_form + for i in range(len(array_form)): + if i == len(array_form) - 1: + if array_form[i][1] == 1: + str_form += str(array_form[i][0]) + else: + str_form += str(array_form[i][0]) + \ + "**" + str(array_form[i][1]) + else: + if array_form[i][1] == 1: + str_form += str(array_form[i][0]) + "*" + else: + str_form += str(array_form[i][0]) + \ + "**" + str(array_form[i][1]) + "*" + return str_form + + __repr__ = __str__ + + def __pow__(self, n): + n = as_int(n) + result = self.group.identity + if n == 0: + return result + if n < 0: + n = -n + x = self.inverse() + else: + x = self + while True: + if n % 2: + result *= x + n >>= 1 + if not n: + break + x *= x + return result + + def __mul__(self, other): + """Returns the product of elements belonging to the same ``FreeGroup``. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, x, y, z = free_group("x y z") + >>> x*y**2*y**-4 + x*y**-2 + >>> z*y**-2 + z*y**-2 + >>> x**2*y*y**-1*x**-2 + + + """ + group = self.group + if not isinstance(other, group.dtype): + raise TypeError("only FreeGroup elements of same FreeGroup can " + "be multiplied") + if self.is_identity: + return other + if other.is_identity: + return self + r = list(self.array_form + other.array_form) + zero_mul_simp(r, len(self.array_form) - 1) + return group.dtype(tuple(r)) + + def __truediv__(self, other): + group = self.group + if not isinstance(other, group.dtype): + raise TypeError("only FreeGroup elements of same FreeGroup can " + "be multiplied") + return self*(other.inverse()) + + def __rtruediv__(self, other): + group = self.group + if not isinstance(other, group.dtype): + raise TypeError("only FreeGroup elements of same FreeGroup can " + "be multiplied") + return other*(self.inverse()) + + def __add__(self, other): + return NotImplemented + + def inverse(self): + """ + Returns the inverse of a ``FreeGroupElement`` element + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, x, y, z = free_group("x y z") + >>> x.inverse() + x**-1 + >>> (x*y).inverse() + y**-1*x**-1 + + """ + group = self.group + r = tuple([(i, -j) for i, j in self.array_form[::-1]]) + return group.dtype(r) + + def order(self): + """Find the order of a ``FreeGroupElement``. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, x, y = free_group("x y") + >>> (x**2*y*y**-1*x**-2).order() + 1 + + """ + if self.is_identity: + return S.One + else: + return S.Infinity + + def commutator(self, other): + """ + Return the commutator of `self` and `x`: ``~x*~self*x*self`` + + """ + group = self.group + if not isinstance(other, group.dtype): + raise ValueError("commutator of only FreeGroupElement of the same " + "FreeGroup exists") + else: + return self.inverse()*other.inverse()*self*other + + def eliminate_words(self, words, _all=False, inverse=True): + ''' + Replace each subword from the dictionary `words` by words[subword]. + If words is a list, replace the words by the identity. + + ''' + again = True + new = self + if isinstance(words, dict): + while again: + again = False + for sub in words: + prev = new + new = new.eliminate_word(sub, words[sub], _all=_all, inverse=inverse) + if new != prev: + again = True + else: + while again: + again = False + for sub in words: + prev = new + new = new.eliminate_word(sub, _all=_all, inverse=inverse) + if new != prev: + again = True + return new + + def eliminate_word(self, gen, by=None, _all=False, inverse=True): + """ + For an associative word `self`, a subword `gen`, and an associative + word `by` (identity by default), return the associative word obtained by + replacing each occurrence of `gen` in `self` by `by`. If `_all = True`, + the occurrences of `gen` that may appear after the first substitution will + also be replaced and so on until no occurrences are found. This might not + always terminate (e.g. `(x).eliminate_word(x, x**2, _all=True)`). + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, x, y = free_group("x y") + >>> w = x**5*y*x**2*y**-4*x + >>> w.eliminate_word( x, x**2 ) + x**10*y*x**4*y**-4*x**2 + >>> w.eliminate_word( x, y**-1 ) + y**-11 + >>> w.eliminate_word(x**5) + y*x**2*y**-4*x + >>> w.eliminate_word(x*y, y) + x**4*y*x**2*y**-4*x + + See Also + ======== + substituted_word + + """ + if by is None: + by = self.group.identity + if self.is_independent(gen) or gen == by: + return self + if gen == self: + return by + if gen**-1 == by: + _all = False + word = self + l = len(gen) + + try: + i = word.subword_index(gen) + k = 1 + except ValueError: + if not inverse: + return word + try: + i = word.subword_index(gen**-1) + k = -1 + except ValueError: + return word + + word = word.subword(0, i)*by**k*word.subword(i+l, len(word)).eliminate_word(gen, by) + + if _all: + return word.eliminate_word(gen, by, _all=True, inverse=inverse) + else: + return word + + def __len__(self): + """ + For an associative word `self`, returns the number of letters in it. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, a, b = free_group("a b") + >>> w = a**5*b*a**2*b**-4*a + >>> len(w) + 13 + >>> len(a**17) + 17 + >>> len(w**0) + 0 + + """ + return sum(abs(j) for (i, j) in self) + + def __eq__(self, other): + """ + Two associative words are equal if they are words over the + same alphabet and if they are sequences of the same letters. + This is equivalent to saying that the external representations + of the words are equal. + There is no "universal" empty word, every alphabet has its own + empty word. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, swapnil0, swapnil1 = free_group("swapnil0 swapnil1") + >>> f + + >>> g, swap0, swap1 = free_group("swap0 swap1") + >>> g + + + >>> swapnil0 == swapnil1 + False + >>> swapnil0*swapnil1 == swapnil1/swapnil1*swapnil0*swapnil1 + True + >>> swapnil0*swapnil1 == swapnil1*swapnil0 + False + >>> swapnil1**0 == swap0**0 + False + + """ + group = self.group + if not isinstance(other, group.dtype): + return False + return tuple.__eq__(self, other) + + def __lt__(self, other): + """ + The ordering of associative words is defined by length and + lexicography (this ordering is called short-lex ordering), that + is, shorter words are smaller than longer words, and words of the + same length are compared w.r.t. the lexicographical ordering induced + by the ordering of generators. Generators are sorted according + to the order in which they were created. If the generators are + invertible then each generator `g` is larger than its inverse `g^{-1}`, + and `g^{-1}` is larger than every generator that is smaller than `g`. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, a, b = free_group("a b") + >>> b < a + False + >>> a < a.inverse() + False + + """ + group = self.group + if not isinstance(other, group.dtype): + raise TypeError("only FreeGroup elements of same FreeGroup can " + "be compared") + l = len(self) + m = len(other) + # implement lenlex order + if l < m: + return True + elif l > m: + return False + for i in range(l): + a = self[i].array_form[0] + b = other[i].array_form[0] + p = group.symbols.index(a[0]) + q = group.symbols.index(b[0]) + if p < q: + return True + elif p > q: + return False + elif a[1] < b[1]: + return True + elif a[1] > b[1]: + return False + return False + + def __le__(self, other): + return (self == other or self < other) + + def __gt__(self, other): + """ + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, x, y, z = free_group("x y z") + >>> y**2 > x**2 + True + >>> y*z > z*y + False + >>> x > x.inverse() + True + + """ + group = self.group + if not isinstance(other, group.dtype): + raise TypeError("only FreeGroup elements of same FreeGroup can " + "be compared") + return not self <= other + + def __ge__(self, other): + return not self < other + + def exponent_sum(self, gen): + """ + For an associative word `self` and a generator or inverse of generator + `gen`, ``exponent_sum`` returns the number of times `gen` appears in + `self` minus the number of times its inverse appears in `self`. If + neither `gen` nor its inverse occur in `self` then 0 is returned. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> F, x, y = free_group("x, y") + >>> w = x**2*y**3 + >>> w.exponent_sum(x) + 2 + >>> w.exponent_sum(x**-1) + -2 + >>> w = x**2*y**4*x**-3 + >>> w.exponent_sum(x) + -1 + + See Also + ======== + + generator_count + + """ + if len(gen) != 1: + raise ValueError("gen must be a generator or inverse of a generator") + s = gen.array_form[0] + return s[1]*sum(i[1] for i in self.array_form if i[0] == s[0]) + + def generator_count(self, gen): + """ + For an associative word `self` and a generator `gen`, + ``generator_count`` returns the multiplicity of generator + `gen` in `self`. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> F, x, y = free_group("x, y") + >>> w = x**2*y**3 + >>> w.generator_count(x) + 2 + >>> w = x**2*y**4*x**-3 + >>> w.generator_count(x) + 5 + + See Also + ======== + + exponent_sum + + """ + if len(gen) != 1 or gen.array_form[0][1] < 0: + raise ValueError("gen must be a generator") + s = gen.array_form[0] + return s[1]*sum(abs(i[1]) for i in self.array_form if i[0] == s[0]) + + def subword(self, from_i, to_j, strict=True): + """ + For an associative word `self` and two positive integers `from_i` and + `to_j`, `subword` returns the subword of `self` that begins at position + `from_i` and ends at `to_j - 1`, indexing is done with origin 0. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, a, b = free_group("a b") + >>> w = a**5*b*a**2*b**-4*a + >>> w.subword(2, 6) + a**3*b + + """ + group = self.group + if not strict: + from_i = max(from_i, 0) + to_j = min(len(self), to_j) + if from_i < 0 or to_j > len(self): + raise ValueError("`from_i`, `to_j` must be positive and no greater than " + "the length of associative word") + if to_j <= from_i: + return group.identity + else: + letter_form = self.letter_form[from_i: to_j] + array_form = letter_form_to_array_form(letter_form, group) + return group.dtype(array_form) + + def subword_index(self, word, start = 0): + ''' + Find the index of `word` in `self`. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, a, b = free_group("a b") + >>> w = a**2*b*a*b**3 + >>> w.subword_index(a*b*a*b) + 1 + + ''' + l = len(word) + self_lf = self.letter_form + word_lf = word.letter_form + index = None + for i in range(start,len(self_lf)-l+1): + if self_lf[i:i+l] == word_lf: + index = i + break + if index is not None: + return index + else: + raise ValueError("The given word is not a subword of self") + + def is_dependent(self, word): + """ + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> F, x, y = free_group("x, y") + >>> (x**4*y**-3).is_dependent(x**4*y**-2) + True + >>> (x**2*y**-1).is_dependent(x*y) + False + >>> (x*y**2*x*y**2).is_dependent(x*y**2) + True + >>> (x**12).is_dependent(x**-4) + True + + See Also + ======== + + is_independent + + """ + try: + return self.subword_index(word) is not None + except ValueError: + pass + try: + return self.subword_index(word**-1) is not None + except ValueError: + return False + + def is_independent(self, word): + """ + + See Also + ======== + + is_dependent + + """ + return not self.is_dependent(word) + + def contains_generators(self): + """ + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> F, x, y, z = free_group("x, y, z") + >>> (x**2*y**-1).contains_generators() + {x, y} + >>> (x**3*z).contains_generators() + {x, z} + + """ + group = self.group + gens = {group.dtype(((syllable[0], 1),)) for syllable in self.array_form} + return gens + + def cyclic_subword(self, from_i, to_j): + group = self.group + l = len(self) + letter_form = self.letter_form + period1 = int(from_i/l) + if from_i >= l: + from_i -= l*period1 + to_j -= l*period1 + diff = to_j - from_i + word = letter_form[from_i: to_j] + period2 = int(to_j/l) - 1 + word += letter_form*period2 + letter_form[:diff-l+from_i-l*period2] + word = letter_form_to_array_form(word, group) + return group.dtype(word) + + def cyclic_conjugates(self): + """Returns a words which are cyclic to the word `self`. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> F, x, y = free_group("x, y") + >>> w = x*y*x*y*x + >>> w.cyclic_conjugates() + {x*y*x**2*y, x**2*y*x*y, y*x*y*x**2, y*x**2*y*x, x*y*x*y*x} + >>> s = x*y*x**2*y*x + >>> s.cyclic_conjugates() + {x**2*y*x**2*y, y*x**2*y*x**2, x*y*x**2*y*x} + + References + ========== + + .. [1] https://planetmath.org/cyclicpermutation + + """ + return {self.cyclic_subword(i, i+len(self)) for i in range(len(self))} + + def is_cyclic_conjugate(self, w): + """ + Checks whether words ``self``, ``w`` are cyclic conjugates. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> F, x, y = free_group("x, y") + >>> w1 = x**2*y**5 + >>> w2 = x*y**5*x + >>> w1.is_cyclic_conjugate(w2) + True + >>> w3 = x**-1*y**5*x**-1 + >>> w3.is_cyclic_conjugate(w2) + False + + """ + l1 = len(self) + l2 = len(w) + if l1 != l2: + return False + w1 = self.identity_cyclic_reduction() + w2 = w.identity_cyclic_reduction() + letter1 = w1.letter_form + letter2 = w2.letter_form + str1 = ' '.join(map(str, letter1)) + str2 = ' '.join(map(str, letter2)) + if len(str1) != len(str2): + return False + + return str1 in str2 + ' ' + str2 + + def number_syllables(self): + """Returns the number of syllables of the associative word `self`. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, swapnil0, swapnil1 = free_group("swapnil0 swapnil1") + >>> (swapnil1**3*swapnil0*swapnil1**-1).number_syllables() + 3 + + """ + return len(self.array_form) + + def exponent_syllable(self, i): + """ + Returns the exponent of the `i`-th syllable of the associative word + `self`. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, a, b = free_group("a b") + >>> w = a**5*b*a**2*b**-4*a + >>> w.exponent_syllable( 2 ) + 2 + + """ + return self.array_form[i][1] + + def generator_syllable(self, i): + """ + Returns the symbol of the generator that is involved in the + i-th syllable of the associative word `self`. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, a, b = free_group("a b") + >>> w = a**5*b*a**2*b**-4*a + >>> w.generator_syllable( 3 ) + b + + """ + return self.array_form[i][0] + + def sub_syllables(self, from_i, to_j): + """ + `sub_syllables` returns the subword of the associative word `self` that + consists of syllables from positions `from_to` to `to_j`, where + `from_to` and `to_j` must be positive integers and indexing is done + with origin 0. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> f, a, b = free_group("a, b") + >>> w = a**5*b*a**2*b**-4*a + >>> w.sub_syllables(1, 2) + b + >>> w.sub_syllables(3, 3) + + + """ + if not isinstance(from_i, int) or not isinstance(to_j, int): + raise ValueError("both arguments should be integers") + group = self.group + if to_j <= from_i: + return group.identity + else: + r = tuple(self.array_form[from_i: to_j]) + return group.dtype(r) + + def substituted_word(self, from_i, to_j, by): + """ + Returns the associative word obtained by replacing the subword of + `self` that begins at position `from_i` and ends at position `to_j - 1` + by the associative word `by`. `from_i` and `to_j` must be positive + integers, indexing is done with origin 0. In other words, + `w.substituted_word(w, from_i, to_j, by)` is the product of the three + words: `w.subword(0, from_i)`, `by`, and + `w.subword(to_j len(w))`. + + See Also + ======== + + eliminate_word + + """ + lw = len(self) + if from_i >= to_j or from_i > lw or to_j > lw: + raise ValueError("values should be within bounds") + + # otherwise there are four possibilities + + # first if from=1 and to=lw then + if from_i == 0 and to_j == lw: + return by + elif from_i == 0: # second if from_i=1 (and to_j < lw) then + return by*self.subword(to_j, lw) + elif to_j == lw: # third if to_j=1 (and from_i > 1) then + return self.subword(0, from_i)*by + else: # finally + return self.subword(0, from_i)*by*self.subword(to_j, lw) + + def is_cyclically_reduced(self): + r"""Returns whether the word is cyclically reduced or not. + A word is cyclically reduced if by forming the cycle of the + word, the word is not reduced, i.e a word w = `a_1 ... a_n` + is called cyclically reduced if `a_1 \ne a_n^{-1}`. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> F, x, y = free_group("x, y") + >>> (x**2*y**-1*x**-1).is_cyclically_reduced() + False + >>> (y*x**2*y**2).is_cyclically_reduced() + True + + """ + if not self: + return True + return self[0] != self[-1]**-1 + + def identity_cyclic_reduction(self): + """Return a unique cyclically reduced version of the word. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> F, x, y = free_group("x, y") + >>> (x**2*y**2*x**-1).identity_cyclic_reduction() + x*y**2 + >>> (x**-3*y**-1*x**5).identity_cyclic_reduction() + x**2*y**-1 + + References + ========== + + .. [1] https://planetmath.org/cyclicallyreduced + + """ + word = self.copy() + group = self.group + while not word.is_cyclically_reduced(): + exp1 = word.exponent_syllable(0) + exp2 = word.exponent_syllable(-1) + r = exp1 + exp2 + if r == 0: + rep = word.array_form[1: word.number_syllables() - 1] + else: + rep = ((word.generator_syllable(0), exp1 + exp2),) + \ + word.array_form[1: word.number_syllables() - 1] + word = group.dtype(rep) + return word + + def cyclic_reduction(self, removed=False): + """Return a cyclically reduced version of the word. Unlike + `identity_cyclic_reduction`, this will not cyclically permute + the reduced word - just remove the "unreduced" bits on either + side of it. Compare the examples with those of + `identity_cyclic_reduction`. + + When `removed` is `True`, return a tuple `(word, r)` where + self `r` is such that before the reduction the word was either + `r*word*r**-1`. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> F, x, y = free_group("x, y") + >>> (x**2*y**2*x**-1).cyclic_reduction() + x*y**2 + >>> (x**-3*y**-1*x**5).cyclic_reduction() + y**-1*x**2 + >>> (x**-3*y**-1*x**5).cyclic_reduction(removed=True) + (y**-1*x**2, x**-3) + + """ + word = self.copy() + g = self.group.identity + while not word.is_cyclically_reduced(): + exp1 = abs(word.exponent_syllable(0)) + exp2 = abs(word.exponent_syllable(-1)) + exp = min(exp1, exp2) + start = word[0]**abs(exp) + end = word[-1]**abs(exp) + word = start**-1*word*end**-1 + g = g*start + if removed: + return word, g + return word + + def power_of(self, other): + ''' + Check if `self == other**n` for some integer n. + + Examples + ======== + + >>> from sympy.combinatorics import free_group + >>> F, x, y = free_group("x, y") + >>> ((x*y)**2).power_of(x*y) + True + >>> (x**-3*y**-2*x**3).power_of(x**-3*y*x**3) + True + + ''' + if self.is_identity: + return True + + l = len(other) + if l == 1: + # self has to be a power of one generator + gens = self.contains_generators() + s = other in gens or other**-1 in gens + return len(gens) == 1 and s + + # if self is not cyclically reduced and it is a power of other, + # other isn't cyclically reduced and the parts removed during + # their reduction must be equal + reduced, r1 = self.cyclic_reduction(removed=True) + if not r1.is_identity: + other, r2 = other.cyclic_reduction(removed=True) + if r1 == r2: + return reduced.power_of(other) + return False + + if len(self) < l or len(self) % l: + return False + + prefix = self.subword(0, l) + if prefix == other or prefix**-1 == other: + rest = self.subword(l, len(self)) + return rest.power_of(other) + return False + + +def letter_form_to_array_form(array_form, group): + """ + This method converts a list given with possible repetitions of elements in + it. It returns a new list such that repetitions of consecutive elements is + removed and replace with a tuple element of size two such that the first + index contains `value` and the second index contains the number of + consecutive repetitions of `value`. + + """ + a = list(array_form[:]) + new_array = [] + n = 1 + symbols = group.symbols + for i in range(len(a)): + if i == len(a) - 1: + if a[i] == a[i - 1]: + if (-a[i]) in symbols: + new_array.append((-a[i], -n)) + else: + new_array.append((a[i], n)) + else: + if (-a[i]) in symbols: + new_array.append((-a[i], -1)) + else: + new_array.append((a[i], 1)) + return new_array + elif a[i] == a[i + 1]: + n += 1 + else: + if (-a[i]) in symbols: + new_array.append((-a[i], -n)) + else: + new_array.append((a[i], n)) + n = 1 + + +def zero_mul_simp(l, index): + """Used to combine two reduced words.""" + while index >=0 and index < len(l) - 1 and l[index][0] == l[index + 1][0]: + exp = l[index][1] + l[index + 1][1] + base = l[index][0] + l[index] = (base, exp) + del l[index + 1] + if l[index][1] == 0: + del l[index] + index -= 1 diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/galois.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/galois.py new file mode 100644 index 0000000000000000000000000000000000000000..f8ff89886283903e116bf40e1be6f6131386e802 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/galois.py @@ -0,0 +1,611 @@ +r""" +Construct transitive subgroups of symmetric groups, useful in Galois theory. + +Besides constructing instances of the :py:class:`~.PermutationGroup` class to +represent the transitive subgroups of $S_n$ for small $n$, this module provides +*names* for these groups. + +In some applications, it may be preferable to know the name of a group, +rather than receive an instance of the :py:class:`~.PermutationGroup` +class, and then have to do extra work to determine which group it is, by +checking various properties. + +Names are instances of ``Enum`` classes defined in this module. With a name in +hand, the name's ``get_perm_group`` method can then be used to retrieve a +:py:class:`~.PermutationGroup`. + +The names used for groups in this module are taken from [1]. + +References +========== + +.. [1] Cohen, H. *A Course in Computational Algebraic Number Theory*. + +""" + +from collections import defaultdict +from enum import Enum +import itertools + +from sympy.combinatorics.named_groups import ( + SymmetricGroup, AlternatingGroup, CyclicGroup, DihedralGroup, + set_symmetric_group_properties, set_alternating_group_properties, +) +from sympy.combinatorics.perm_groups import PermutationGroup +from sympy.combinatorics.permutations import Permutation + + +class S1TransitiveSubgroups(Enum): + """ + Names for the transitive subgroups of S1. + """ + S1 = "S1" + + def get_perm_group(self): + return SymmetricGroup(1) + + +class S2TransitiveSubgroups(Enum): + """ + Names for the transitive subgroups of S2. + """ + S2 = "S2" + + def get_perm_group(self): + return SymmetricGroup(2) + + +class S3TransitiveSubgroups(Enum): + """ + Names for the transitive subgroups of S3. + """ + A3 = "A3" + S3 = "S3" + + def get_perm_group(self): + if self == S3TransitiveSubgroups.A3: + return AlternatingGroup(3) + elif self == S3TransitiveSubgroups.S3: + return SymmetricGroup(3) + + +class S4TransitiveSubgroups(Enum): + """ + Names for the transitive subgroups of S4. + """ + C4 = "C4" + V = "V" + D4 = "D4" + A4 = "A4" + S4 = "S4" + + def get_perm_group(self): + if self == S4TransitiveSubgroups.C4: + return CyclicGroup(4) + elif self == S4TransitiveSubgroups.V: + return four_group() + elif self == S4TransitiveSubgroups.D4: + return DihedralGroup(4) + elif self == S4TransitiveSubgroups.A4: + return AlternatingGroup(4) + elif self == S4TransitiveSubgroups.S4: + return SymmetricGroup(4) + + +class S5TransitiveSubgroups(Enum): + """ + Names for the transitive subgroups of S5. + """ + C5 = "C5" + D5 = "D5" + M20 = "M20" + A5 = "A5" + S5 = "S5" + + def get_perm_group(self): + if self == S5TransitiveSubgroups.C5: + return CyclicGroup(5) + elif self == S5TransitiveSubgroups.D5: + return DihedralGroup(5) + elif self == S5TransitiveSubgroups.M20: + return M20() + elif self == S5TransitiveSubgroups.A5: + return AlternatingGroup(5) + elif self == S5TransitiveSubgroups.S5: + return SymmetricGroup(5) + + +class S6TransitiveSubgroups(Enum): + """ + Names for the transitive subgroups of S6. + """ + C6 = "C6" + S3 = "S3" + D6 = "D6" + A4 = "A4" + G18 = "G18" + A4xC2 = "A4 x C2" + S4m = "S4-" + S4p = "S4+" + G36m = "G36-" + G36p = "G36+" + S4xC2 = "S4 x C2" + PSL2F5 = "PSL2(F5)" + G72 = "G72" + PGL2F5 = "PGL2(F5)" + A6 = "A6" + S6 = "S6" + + def get_perm_group(self): + if self == S6TransitiveSubgroups.C6: + return CyclicGroup(6) + elif self == S6TransitiveSubgroups.S3: + return S3_in_S6() + elif self == S6TransitiveSubgroups.D6: + return DihedralGroup(6) + elif self == S6TransitiveSubgroups.A4: + return A4_in_S6() + elif self == S6TransitiveSubgroups.G18: + return G18() + elif self == S6TransitiveSubgroups.A4xC2: + return A4xC2() + elif self == S6TransitiveSubgroups.S4m: + return S4m() + elif self == S6TransitiveSubgroups.S4p: + return S4p() + elif self == S6TransitiveSubgroups.G36m: + return G36m() + elif self == S6TransitiveSubgroups.G36p: + return G36p() + elif self == S6TransitiveSubgroups.S4xC2: + return S4xC2() + elif self == S6TransitiveSubgroups.PSL2F5: + return PSL2F5() + elif self == S6TransitiveSubgroups.G72: + return G72() + elif self == S6TransitiveSubgroups.PGL2F5: + return PGL2F5() + elif self == S6TransitiveSubgroups.A6: + return AlternatingGroup(6) + elif self == S6TransitiveSubgroups.S6: + return SymmetricGroup(6) + + +def four_group(): + """ + Return a representation of the Klein four-group as a transitive subgroup + of S4. + """ + return PermutationGroup( + Permutation(0, 1)(2, 3), + Permutation(0, 2)(1, 3) + ) + + +def M20(): + """ + Return a representation of the metacyclic group M20, a transitive subgroup + of S5 that is one of the possible Galois groups for polys of degree 5. + + Notes + ===== + + See [1], Page 323. + + """ + G = PermutationGroup(Permutation(0, 1, 2, 3, 4), Permutation(1, 2, 4, 3)) + G._degree = 5 + G._order = 20 + G._is_transitive = True + G._is_sym = False + G._is_alt = False + G._is_cyclic = False + G._is_dihedral = False + return G + + +def S3_in_S6(): + """ + Return a representation of S3 as a transitive subgroup of S6. + + Notes + ===== + + The representation is found by viewing the group as the symmetries of a + triangular prism. + + """ + G = PermutationGroup(Permutation(0, 1, 2)(3, 4, 5), Permutation(0, 3)(2, 4)(1, 5)) + set_symmetric_group_properties(G, 3, 6) + return G + + +def A4_in_S6(): + """ + Return a representation of A4 as a transitive subgroup of S6. + + Notes + ===== + + This was computed using :py:func:`~.find_transitive_subgroups_of_S6`. + + """ + G = PermutationGroup(Permutation(0, 4, 5)(1, 3, 2), Permutation(0, 1, 2)(3, 5, 4)) + set_alternating_group_properties(G, 4, 6) + return G + + +def S4m(): + """ + Return a representation of the S4- transitive subgroup of S6. + + Notes + ===== + + This was computed using :py:func:`~.find_transitive_subgroups_of_S6`. + + """ + G = PermutationGroup(Permutation(1, 4, 5, 3), Permutation(0, 4)(1, 5)(2, 3)) + set_symmetric_group_properties(G, 4, 6) + return G + + +def S4p(): + """ + Return a representation of the S4+ transitive subgroup of S6. + + Notes + ===== + + This was computed using :py:func:`~.find_transitive_subgroups_of_S6`. + + """ + G = PermutationGroup(Permutation(0, 2, 4, 1)(3, 5), Permutation(0, 3)(4, 5)) + set_symmetric_group_properties(G, 4, 6) + return G + + +def A4xC2(): + """ + Return a representation of the (A4 x C2) transitive subgroup of S6. + + Notes + ===== + + This was computed using :py:func:`~.find_transitive_subgroups_of_S6`. + + """ + return PermutationGroup( + Permutation(0, 4, 5)(1, 3, 2), Permutation(0, 1, 2)(3, 5, 4), + Permutation(5)(2, 4)) + + +def S4xC2(): + """ + Return a representation of the (S4 x C2) transitive subgroup of S6. + + Notes + ===== + + This was computed using :py:func:`~.find_transitive_subgroups_of_S6`. + + """ + return PermutationGroup( + Permutation(1, 4, 5, 3), Permutation(0, 4)(1, 5)(2, 3), + Permutation(1, 4)(3, 5)) + + +def G18(): + """ + Return a representation of the group G18, a transitive subgroup of S6 + isomorphic to the semidirect product of C3^2 with C2. + + Notes + ===== + + This was computed using :py:func:`~.find_transitive_subgroups_of_S6`. + + """ + return PermutationGroup( + Permutation(5)(0, 1, 2), Permutation(3, 4, 5), + Permutation(0, 4)(1, 5)(2, 3)) + + +def G36m(): + """ + Return a representation of the group G36-, a transitive subgroup of S6 + isomorphic to the semidirect product of C3^2 with C2^2. + + Notes + ===== + + This was computed using :py:func:`~.find_transitive_subgroups_of_S6`. + + """ + return PermutationGroup( + Permutation(5)(0, 1, 2), Permutation(3, 4, 5), + Permutation(1, 2)(3, 5), Permutation(0, 4)(1, 5)(2, 3)) + + +def G36p(): + """ + Return a representation of the group G36+, a transitive subgroup of S6 + isomorphic to the semidirect product of C3^2 with C4. + + Notes + ===== + + This was computed using :py:func:`~.find_transitive_subgroups_of_S6`. + + """ + return PermutationGroup( + Permutation(5)(0, 1, 2), Permutation(3, 4, 5), + Permutation(0, 5, 2, 3)(1, 4)) + + +def G72(): + """ + Return a representation of the group G72, a transitive subgroup of S6 + isomorphic to the semidirect product of C3^2 with D4. + + Notes + ===== + + See [1], Page 325. + + """ + return PermutationGroup( + Permutation(5)(0, 1, 2), + Permutation(0, 4, 1, 3)(2, 5), Permutation(0, 3)(1, 4)(2, 5)) + + +def PSL2F5(): + r""" + Return a representation of the group $PSL_2(\mathbb{F}_5)$, as a transitive + subgroup of S6, isomorphic to $A_5$. + + Notes + ===== + + This was computed using :py:func:`~.find_transitive_subgroups_of_S6`. + + """ + G = PermutationGroup( + Permutation(0, 4, 5)(1, 3, 2), Permutation(0, 4, 3, 1, 5)) + set_alternating_group_properties(G, 5, 6) + return G + + +def PGL2F5(): + r""" + Return a representation of the group $PGL_2(\mathbb{F}_5)$, as a transitive + subgroup of S6, isomorphic to $S_5$. + + Notes + ===== + + See [1], Page 325. + + """ + G = PermutationGroup( + Permutation(0, 1, 2, 3, 4), Permutation(0, 5)(1, 2)(3, 4)) + set_symmetric_group_properties(G, 5, 6) + return G + + +def find_transitive_subgroups_of_S6(*targets, print_report=False): + r""" + Search for certain transitive subgroups of $S_6$. + + The symmetric group $S_6$ has 16 different transitive subgroups, up to + conjugacy. Some are more easily constructed than others. For example, the + dihedral group $D_6$ is immediately found, but it is not at all obvious how + to realize $S_4$ or $S_5$ *transitively* within $S_6$. + + In some cases there are well-known constructions that can be used. For + example, $S_5$ is isomorphic to $PGL_2(\mathbb{F}_5)$, which acts in a + natural way on the projective line $P^1(\mathbb{F}_5)$, a set of order 6. + + In absence of such special constructions however, we can simply search for + generators. For example, transitive instances of $A_4$ and $S_4$ can be + found within $S_6$ in this way. + + Once we are engaged in such searches, it may then be easier (if less + elegant) to find even those groups like $S_5$ that do have special + constructions, by mere search. + + This function locates generators for transitive instances in $S_6$ of the + following subgroups: + + * $A_4$ + * $S_4^-$ ($S_4$ not contained within $A_6$) + * $S_4^+$ ($S_4$ contained within $A_6$) + * $A_4 \times C_2$ + * $S_4 \times C_2$ + * $G_{18} = C_3^2 \rtimes C_2$ + * $G_{36}^- = C_3^2 \rtimes C_2^2$ + * $G_{36}^+ = C_3^2 \rtimes C_4$ + * $G_{72} = C_3^2 \rtimes D_4$ + * $A_5$ + * $S_5$ + + Note: Each of these groups also has a dedicated function in this module + that returns the group immediately, using generators that were found by + this search procedure. + + The search procedure serves as a record of how these generators were + found. Also, due to randomness in the generation of the elements of + permutation groups, it can be called again, in order to (probably) get + different generators for the same groups. + + Parameters + ========== + + targets : list of :py:class:`~.S6TransitiveSubgroups` values + The groups you want to find. + + print_report : bool (default False) + If True, print to stdout the generators found for each group. + + Returns + ======= + + dict + mapping each name in *targets* to the :py:class:`~.PermutationGroup` + that was found + + References + ========== + + .. [2] https://en.wikipedia.org/wiki/Projective_linear_group#Exceptional_isomorphisms + .. [3] https://en.wikipedia.org/wiki/Automorphisms_of_the_symmetric_and_alternating_groups#PGL%282,5%29 + + """ + def elts_by_order(G): + """Sort the elements of a group by their order. """ + elts = defaultdict(list) + for g in G.elements: + elts[g.order()].append(g) + return elts + + def order_profile(G, name=None): + """Determine how many elements a group has, of each order. """ + elts = elts_by_order(G) + profile = {o:len(e) for o, e in elts.items()} + if name: + print(f'{name}: ' + ' '.join(f'{len(profile[r])}@{r}' for r in sorted(profile.keys()))) + return profile + + S6 = SymmetricGroup(6) + A6 = AlternatingGroup(6) + S6_by_order = elts_by_order(S6) + + def search(existing_gens, needed_gen_orders, order, alt=None, profile=None, anti_profile=None): + """ + Find a transitive subgroup of S6. + + Parameters + ========== + + existing_gens : list of Permutation + Optionally empty list of generators that must be in the group. + + needed_gen_orders : list of positive int + Nonempty list of the orders of the additional generators that are + to be found. + + order: int + The order of the group being sought. + + alt: bool, None + If True, require the group to be contained in A6. + If False, require the group not to be contained in A6. + + profile : dict + If given, the group's order profile must equal this. + + anti_profile : dict + If given, the group's order profile must *not* equal this. + + """ + for gens in itertools.product(*[S6_by_order[n] for n in needed_gen_orders]): + if len(set(gens)) < len(gens): + continue + G = PermutationGroup(existing_gens + list(gens)) + if G.order() == order and G.is_transitive(): + if alt is not None and G.is_subgroup(A6) != alt: + continue + if profile and order_profile(G) != profile: + continue + if anti_profile and order_profile(G) == anti_profile: + continue + return G + + def match_known_group(G, alt=None): + needed = [g.order() for g in G.generators] + return search([], needed, G.order(), alt=alt, profile=order_profile(G)) + + found = {} + + def finish_up(name, G): + found[name] = G + if print_report: + print("=" * 40) + print(f"{name}:") + print(G.generators) + + if S6TransitiveSubgroups.A4 in targets or S6TransitiveSubgroups.A4xC2 in targets: + A4_in_S6 = match_known_group(AlternatingGroup(4)) + finish_up(S6TransitiveSubgroups.A4, A4_in_S6) + + if S6TransitiveSubgroups.S4m in targets or S6TransitiveSubgroups.S4xC2 in targets: + S4m_in_S6 = match_known_group(SymmetricGroup(4), alt=False) + finish_up(S6TransitiveSubgroups.S4m, S4m_in_S6) + + if S6TransitiveSubgroups.S4p in targets: + S4p_in_S6 = match_known_group(SymmetricGroup(4), alt=True) + finish_up(S6TransitiveSubgroups.S4p, S4p_in_S6) + + if S6TransitiveSubgroups.A4xC2 in targets: + A4xC2_in_S6 = search(A4_in_S6.generators, [2], 24, anti_profile=order_profile(SymmetricGroup(4))) + finish_up(S6TransitiveSubgroups.A4xC2, A4xC2_in_S6) + + if S6TransitiveSubgroups.S4xC2 in targets: + S4xC2_in_S6 = search(S4m_in_S6.generators, [2], 48) + finish_up(S6TransitiveSubgroups.S4xC2, S4xC2_in_S6) + + # For the normal factor N = C3^2 in any of the G_n subgroups, we take one + # obvious instance of C3^2 in S6: + N_gens = [Permutation(5)(0, 1, 2), Permutation(5)(3, 4, 5)] + + if S6TransitiveSubgroups.G18 in targets: + G18_in_S6 = search(N_gens, [2], 18) + finish_up(S6TransitiveSubgroups.G18, G18_in_S6) + + if S6TransitiveSubgroups.G36m in targets: + G36m_in_S6 = search(N_gens, [2, 2], 36, alt=False) + finish_up(S6TransitiveSubgroups.G36m, G36m_in_S6) + + if S6TransitiveSubgroups.G36p in targets: + G36p_in_S6 = search(N_gens, [4], 36, alt=True) + finish_up(S6TransitiveSubgroups.G36p, G36p_in_S6) + + if S6TransitiveSubgroups.G72 in targets: + G72_in_S6 = search(N_gens, [4, 2], 72) + finish_up(S6TransitiveSubgroups.G72, G72_in_S6) + + # The PSL2(F5) and PGL2(F5) subgroups are isomorphic to A5 and S5, resp. + + if S6TransitiveSubgroups.PSL2F5 in targets: + PSL2F5_in_S6 = match_known_group(AlternatingGroup(5)) + finish_up(S6TransitiveSubgroups.PSL2F5, PSL2F5_in_S6) + + if S6TransitiveSubgroups.PGL2F5 in targets: + PGL2F5_in_S6 = match_known_group(SymmetricGroup(5)) + finish_up(S6TransitiveSubgroups.PGL2F5, PGL2F5_in_S6) + + # There is little need to "search" for any of the groups C6, S3, D6, A6, + # or S6, since they all have obvious realizations within S6. However, we + # support them here just in case a random representation is desired. + + if S6TransitiveSubgroups.C6 in targets: + C6 = match_known_group(CyclicGroup(6)) + finish_up(S6TransitiveSubgroups.C6, C6) + + if S6TransitiveSubgroups.S3 in targets: + S3 = match_known_group(SymmetricGroup(3)) + finish_up(S6TransitiveSubgroups.S3, S3) + + if S6TransitiveSubgroups.D6 in targets: + D6 = match_known_group(DihedralGroup(6)) + finish_up(S6TransitiveSubgroups.D6, D6) + + if S6TransitiveSubgroups.A6 in targets: + A6 = match_known_group(A6) + finish_up(S6TransitiveSubgroups.A6, A6) + + if S6TransitiveSubgroups.S6 in targets: + S6 = match_known_group(S6) + finish_up(S6TransitiveSubgroups.S6, S6) + + return found diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/generators.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/generators.py new file mode 100644 index 0000000000000000000000000000000000000000..19e274a4c5b4bf0781855db6bc6bf499349fff31 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/generators.py @@ -0,0 +1,301 @@ +from sympy.combinatorics.permutations import Permutation +from sympy.core.symbol import symbols +from sympy.matrices import Matrix +from sympy.utilities.iterables import variations, rotate_left + + +def symmetric(n): + """ + Generates the symmetric group of order n, Sn. + + Examples + ======== + + >>> from sympy.combinatorics.generators import symmetric + >>> list(symmetric(3)) + [(2), (1 2), (2)(0 1), (0 1 2), (0 2 1), (0 2)] + """ + yield from (Permutation(perm) for perm in variations(range(n), n)) + + +def cyclic(n): + """ + Generates the cyclic group of order n, Cn. + + Examples + ======== + + >>> from sympy.combinatorics.generators import cyclic + >>> list(cyclic(5)) + [(4), (0 1 2 3 4), (0 2 4 1 3), + (0 3 1 4 2), (0 4 3 2 1)] + + See Also + ======== + + dihedral + """ + gen = list(range(n)) + for i in range(n): + yield Permutation(gen) + gen = rotate_left(gen, 1) + + +def alternating(n): + """ + Generates the alternating group of order n, An. + + Examples + ======== + + >>> from sympy.combinatorics.generators import alternating + >>> list(alternating(3)) + [(2), (0 1 2), (0 2 1)] + """ + for perm in variations(range(n), n): + p = Permutation(perm) + if p.is_even: + yield p + + +def dihedral(n): + """ + Generates the dihedral group of order 2n, Dn. + + The result is given as a subgroup of Sn, except for the special cases n=1 + (the group S2) and n=2 (the Klein 4-group) where that's not possible + and embeddings in S2 and S4 respectively are given. + + Examples + ======== + + >>> from sympy.combinatorics.generators import dihedral + >>> list(dihedral(3)) + [(2), (0 2), (0 1 2), (1 2), (0 2 1), (2)(0 1)] + + See Also + ======== + + cyclic + """ + if n == 1: + yield Permutation([0, 1]) + yield Permutation([1, 0]) + elif n == 2: + yield Permutation([0, 1, 2, 3]) + yield Permutation([1, 0, 3, 2]) + yield Permutation([2, 3, 0, 1]) + yield Permutation([3, 2, 1, 0]) + else: + gen = list(range(n)) + for i in range(n): + yield Permutation(gen) + yield Permutation(gen[::-1]) + gen = rotate_left(gen, 1) + + +def rubik_cube_generators(): + """Return the permutations of the 3x3 Rubik's cube, see + https://www.gap-system.org/Doc/Examples/rubik.html + """ + a = [ + [(1, 3, 8, 6), (2, 5, 7, 4), (9, 33, 25, 17), (10, 34, 26, 18), + (11, 35, 27, 19)], + [(9, 11, 16, 14), (10, 13, 15, 12), (1, 17, 41, 40), (4, 20, 44, 37), + (6, 22, 46, 35)], + [(17, 19, 24, 22), (18, 21, 23, 20), (6, 25, 43, 16), (7, 28, 42, 13), + (8, 30, 41, 11)], + [(25, 27, 32, 30), (26, 29, 31, 28), (3, 38, 43, 19), (5, 36, 45, 21), + (8, 33, 48, 24)], + [(33, 35, 40, 38), (34, 37, 39, 36), (3, 9, 46, 32), (2, 12, 47, 29), + (1, 14, 48, 27)], + [(41, 43, 48, 46), (42, 45, 47, 44), (14, 22, 30, 38), + (15, 23, 31, 39), (16, 24, 32, 40)] + ] + return [Permutation([[i - 1 for i in xi] for xi in x], size=48) for x in a] + + +def rubik(n): + """Return permutations for an nxn Rubik's cube. + + Permutations returned are for rotation of each of the slice + from the face up to the last face for each of the 3 sides (in this order): + front, right and bottom. Hence, the first n - 1 permutations are for the + slices from the front. + """ + + if n < 2: + raise ValueError('dimension of cube must be > 1') + + # 1-based reference to rows and columns in Matrix + def getr(f, i): + return faces[f].col(n - i) + + def getl(f, i): + return faces[f].col(i - 1) + + def getu(f, i): + return faces[f].row(i - 1) + + def getd(f, i): + return faces[f].row(n - i) + + def setr(f, i, s): + faces[f][:, n - i] = Matrix(n, 1, s) + + def setl(f, i, s): + faces[f][:, i - 1] = Matrix(n, 1, s) + + def setu(f, i, s): + faces[f][i - 1, :] = Matrix(1, n, s) + + def setd(f, i, s): + faces[f][n - i, :] = Matrix(1, n, s) + + # motion of a single face + def cw(F, r=1): + for _ in range(r): + face = faces[F] + rv = [] + for c in range(n): + for r in range(n - 1, -1, -1): + rv.append(face[r, c]) + faces[F] = Matrix(n, n, rv) + + def ccw(F): + cw(F, 3) + + # motion of plane i from the F side; + # fcw(0) moves the F face, fcw(1) moves the plane + # just behind the front face, etc... + def fcw(i, r=1): + for _ in range(r): + if i == 0: + cw(F) + i += 1 + temp = getr(L, i) + setr(L, i, list(getu(D, i))) + setu(D, i, list(reversed(getl(R, i)))) + setl(R, i, list(getd(U, i))) + setd(U, i, list(reversed(temp))) + i -= 1 + + def fccw(i): + fcw(i, 3) + + # motion of the entire cube from the F side + def FCW(r=1): + for _ in range(r): + cw(F) + ccw(B) + cw(U) + t = faces[U] + cw(L) + faces[U] = faces[L] + cw(D) + faces[L] = faces[D] + cw(R) + faces[D] = faces[R] + faces[R] = t + + def FCCW(): + FCW(3) + + # motion of the entire cube from the U side + def UCW(r=1): + for _ in range(r): + cw(U) + ccw(D) + t = faces[F] + faces[F] = faces[R] + faces[R] = faces[B] + faces[B] = faces[L] + faces[L] = t + + def UCCW(): + UCW(3) + + # defining the permutations for the cube + + U, F, R, B, L, D = names = symbols('U, F, R, B, L, D') + + # the faces are represented by nxn matrices + faces = {} + count = 0 + for fi in range(6): + f = [] + for a in range(n**2): + f.append(count) + count += 1 + faces[names[fi]] = Matrix(n, n, f) + + # this will either return the value of the current permutation + # (show != 1) or else append the permutation to the group, g + def perm(show=0): + # add perm to the list of perms + p = [] + for f in names: + p.extend(faces[f]) + if show: + return p + g.append(Permutation(p)) + + g = [] # container for the group's permutations + I = list(range(6*n**2)) # the identity permutation used for checking + + # define permutations corresponding to cw rotations of the planes + # up TO the last plane from that direction; by not including the + # last plane, the orientation of the cube is maintained. + + # F slices + for i in range(n - 1): + fcw(i) + perm() + fccw(i) # restore + assert perm(1) == I + + # R slices + # bring R to front + UCW() + for i in range(n - 1): + fcw(i) + # put it back in place + UCCW() + # record + perm() + # restore + # bring face to front + UCW() + fccw(i) + # restore + UCCW() + assert perm(1) == I + + # D slices + # bring up bottom + FCW() + UCCW() + FCCW() + for i in range(n - 1): + # turn strip + fcw(i) + # put bottom back on the bottom + FCW() + UCW() + FCCW() + # record + perm() + # restore + # bring up bottom + FCW() + UCCW() + FCCW() + # turn strip + fccw(i) + # put bottom back on the bottom + FCW() + UCW() + FCCW() + assert perm(1) == I + + return g diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/graycode.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/graycode.py new file mode 100644 index 0000000000000000000000000000000000000000..930fd337862a70e920a985947d74375b27741293 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/graycode.py @@ -0,0 +1,430 @@ +from sympy.core import Basic, Integer + +import random + + +class GrayCode(Basic): + """ + A Gray code is essentially a Hamiltonian walk on + a n-dimensional cube with edge length of one. + The vertices of the cube are represented by vectors + whose values are binary. The Hamilton walk visits + each vertex exactly once. The Gray code for a 3d + cube is ['000','100','110','010','011','111','101', + '001']. + + A Gray code solves the problem of sequentially + generating all possible subsets of n objects in such + a way that each subset is obtained from the previous + one by either deleting or adding a single object. + In the above example, 1 indicates that the object is + present, and 0 indicates that its absent. + + Gray codes have applications in statistics as well when + we want to compute various statistics related to subsets + in an efficient manner. + + Examples + ======== + + >>> from sympy.combinatorics import GrayCode + >>> a = GrayCode(3) + >>> list(a.generate_gray()) + ['000', '001', '011', '010', '110', '111', '101', '100'] + >>> a = GrayCode(4) + >>> list(a.generate_gray()) + ['0000', '0001', '0011', '0010', '0110', '0111', '0101', '0100', \ + '1100', '1101', '1111', '1110', '1010', '1011', '1001', '1000'] + + References + ========== + + .. [1] Nijenhuis,A. and Wilf,H.S.(1978). + Combinatorial Algorithms. Academic Press. + .. [2] Knuth, D. (2011). The Art of Computer Programming, Vol 4 + Addison Wesley + + + """ + + _skip = False + _current = 0 + _rank = None + + def __new__(cls, n, *args, **kw_args): + """ + Default constructor. + + It takes a single argument ``n`` which gives the dimension of the Gray + code. The starting Gray code string (``start``) or the starting ``rank`` + may also be given; the default is to start at rank = 0 ('0...0'). + + Examples + ======== + + >>> from sympy.combinatorics import GrayCode + >>> a = GrayCode(3) + >>> a + GrayCode(3) + >>> a.n + 3 + + >>> a = GrayCode(3, start='100') + >>> a.current + '100' + + >>> a = GrayCode(4, rank=4) + >>> a.current + '0110' + >>> a.rank + 4 + + """ + if n < 1 or int(n) != n: + raise ValueError( + 'Gray code dimension must be a positive integer, not %i' % n) + n = Integer(n) + args = (n,) + args + obj = Basic.__new__(cls, *args) + if 'start' in kw_args: + obj._current = kw_args["start"] + if len(obj._current) > n: + raise ValueError('Gray code start has length %i but ' + 'should not be greater than %i' % (len(obj._current), n)) + elif 'rank' in kw_args: + if int(kw_args["rank"]) != kw_args["rank"]: + raise ValueError('Gray code rank must be a positive integer, ' + 'not %i' % kw_args["rank"]) + obj._rank = int(kw_args["rank"]) % obj.selections + obj._current = obj.unrank(n, obj._rank) + return obj + + def next(self, delta=1): + """ + Returns the Gray code a distance ``delta`` (default = 1) from the + current value in canonical order. + + + Examples + ======== + + >>> from sympy.combinatorics import GrayCode + >>> a = GrayCode(3, start='110') + >>> a.next().current + '111' + >>> a.next(-1).current + '010' + """ + return GrayCode(self.n, rank=(self.rank + delta) % self.selections) + + @property + def selections(self): + """ + Returns the number of bit vectors in the Gray code. + + Examples + ======== + + >>> from sympy.combinatorics import GrayCode + >>> a = GrayCode(3) + >>> a.selections + 8 + """ + return 2**self.n + + @property + def n(self): + """ + Returns the dimension of the Gray code. + + Examples + ======== + + >>> from sympy.combinatorics import GrayCode + >>> a = GrayCode(5) + >>> a.n + 5 + """ + return self.args[0] + + def generate_gray(self, **hints): + """ + Generates the sequence of bit vectors of a Gray Code. + + Examples + ======== + + >>> from sympy.combinatorics import GrayCode + >>> a = GrayCode(3) + >>> list(a.generate_gray()) + ['000', '001', '011', '010', '110', '111', '101', '100'] + >>> list(a.generate_gray(start='011')) + ['011', '010', '110', '111', '101', '100'] + >>> list(a.generate_gray(rank=4)) + ['110', '111', '101', '100'] + + See Also + ======== + + skip + + References + ========== + + .. [1] Knuth, D. (2011). The Art of Computer Programming, + Vol 4, Addison Wesley + + """ + bits = self.n + start = None + if "start" in hints: + start = hints["start"] + elif "rank" in hints: + start = GrayCode.unrank(self.n, hints["rank"]) + if start is not None: + self._current = start + current = self.current + graycode_bin = gray_to_bin(current) + if len(graycode_bin) > self.n: + raise ValueError('Gray code start has length %i but should ' + 'not be greater than %i' % (len(graycode_bin), bits)) + self._current = int(current, 2) + graycode_int = int(''.join(graycode_bin), 2) + for i in range(graycode_int, 1 << bits): + if self._skip: + self._skip = False + else: + yield self.current + bbtc = (i ^ (i + 1)) + gbtc = (bbtc ^ (bbtc >> 1)) + self._current = (self._current ^ gbtc) + self._current = 0 + + def skip(self): + """ + Skips the bit generation. + + Examples + ======== + + >>> from sympy.combinatorics import GrayCode + >>> a = GrayCode(3) + >>> for i in a.generate_gray(): + ... if i == '010': + ... a.skip() + ... print(i) + ... + 000 + 001 + 011 + 010 + 111 + 101 + 100 + + See Also + ======== + + generate_gray + """ + self._skip = True + + @property + def rank(self): + """ + Ranks the Gray code. + + A ranking algorithm determines the position (or rank) + of a combinatorial object among all the objects w.r.t. + a given order. For example, the 4 bit binary reflected + Gray code (BRGC) '0101' has a rank of 6 as it appears in + the 6th position in the canonical ordering of the family + of 4 bit Gray codes. + + Examples + ======== + + >>> from sympy.combinatorics import GrayCode + >>> a = GrayCode(3) + >>> list(a.generate_gray()) + ['000', '001', '011', '010', '110', '111', '101', '100'] + >>> GrayCode(3, start='100').rank + 7 + >>> GrayCode(3, rank=7).current + '100' + + See Also + ======== + + unrank + + References + ========== + + .. [1] https://web.archive.org/web/20200224064753/http://statweb.stanford.edu/~susan/courses/s208/node12.html + + """ + if self._rank is None: + self._rank = int(gray_to_bin(self.current), 2) + return self._rank + + @property + def current(self): + """ + Returns the currently referenced Gray code as a bit string. + + Examples + ======== + + >>> from sympy.combinatorics import GrayCode + >>> GrayCode(3, start='100').current + '100' + """ + rv = self._current or '0' + if not isinstance(rv, str): + rv = bin(rv)[2:] + return rv.rjust(self.n, '0') + + @classmethod + def unrank(self, n, rank): + """ + Unranks an n-bit sized Gray code of rank k. This method exists + so that a derivative GrayCode class can define its own code of + a given rank. + + The string here is generated in reverse order to allow for tail-call + optimization. + + Examples + ======== + + >>> from sympy.combinatorics import GrayCode + >>> GrayCode(5, rank=3).current + '00010' + >>> GrayCode.unrank(5, 3) + '00010' + + See Also + ======== + + rank + """ + def _unrank(k, n): + if n == 1: + return str(k % 2) + m = 2**(n - 1) + if k < m: + return '0' + _unrank(k, n - 1) + return '1' + _unrank(m - (k % m) - 1, n - 1) + return _unrank(rank, n) + + +def random_bitstring(n): + """ + Generates a random bitlist of length n. + + Examples + ======== + + >>> from sympy.combinatorics.graycode import random_bitstring + >>> random_bitstring(3) # doctest: +SKIP + 100 + """ + return ''.join([random.choice('01') for i in range(n)]) + + +def gray_to_bin(bin_list): + """ + Convert from Gray coding to binary coding. + + We assume big endian encoding. + + Examples + ======== + + >>> from sympy.combinatorics.graycode import gray_to_bin + >>> gray_to_bin('100') + '111' + + See Also + ======== + + bin_to_gray + """ + b = [bin_list[0]] + for i in range(1, len(bin_list)): + b += str(int(b[i - 1] != bin_list[i])) + return ''.join(b) + + +def bin_to_gray(bin_list): + """ + Convert from binary coding to gray coding. + + We assume big endian encoding. + + Examples + ======== + + >>> from sympy.combinatorics.graycode import bin_to_gray + >>> bin_to_gray('111') + '100' + + See Also + ======== + + gray_to_bin + """ + b = [bin_list[0]] + for i in range(1, len(bin_list)): + b += str(int(bin_list[i]) ^ int(bin_list[i - 1])) + return ''.join(b) + + +def get_subset_from_bitstring(super_set, bitstring): + """ + Gets the subset defined by the bitstring. + + Examples + ======== + + >>> from sympy.combinatorics.graycode import get_subset_from_bitstring + >>> get_subset_from_bitstring(['a', 'b', 'c', 'd'], '0011') + ['c', 'd'] + >>> get_subset_from_bitstring(['c', 'a', 'c', 'c'], '1100') + ['c', 'a'] + + See Also + ======== + + graycode_subsets + """ + if len(super_set) != len(bitstring): + raise ValueError("The sizes of the lists are not equal") + return [super_set[i] for i, j in enumerate(bitstring) + if bitstring[i] == '1'] + + +def graycode_subsets(gray_code_set): + """ + Generates the subsets as enumerated by a Gray code. + + Examples + ======== + + >>> from sympy.combinatorics.graycode import graycode_subsets + >>> list(graycode_subsets(['a', 'b', 'c'])) + [[], ['c'], ['b', 'c'], ['b'], ['a', 'b'], ['a', 'b', 'c'], \ + ['a', 'c'], ['a']] + >>> list(graycode_subsets(['a', 'b', 'c', 'c'])) + [[], ['c'], ['c', 'c'], ['c'], ['b', 'c'], ['b', 'c', 'c'], \ + ['b', 'c'], ['b'], ['a', 'b'], ['a', 'b', 'c'], ['a', 'b', 'c', 'c'], \ + ['a', 'b', 'c'], ['a', 'c'], ['a', 'c', 'c'], ['a', 'c'], ['a']] + + See Also + ======== + + get_subset_from_bitstring + """ + for bitstring in list(GrayCode(len(gray_code_set)).generate_gray()): + yield get_subset_from_bitstring(gray_code_set, bitstring) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/group_constructs.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/group_constructs.py new file mode 100644 index 0000000000000000000000000000000000000000..a5c16ec254191646b26eee869763e2926e187da5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/group_constructs.py @@ -0,0 +1,61 @@ +from sympy.combinatorics.perm_groups import PermutationGroup +from sympy.combinatorics.permutations import Permutation +from sympy.utilities.iterables import uniq + +_af_new = Permutation._af_new + + +def DirectProduct(*groups): + """ + Returns the direct product of several groups as a permutation group. + + Explanation + =========== + + This is implemented much like the __mul__ procedure for taking the direct + product of two permutation groups, but the idea of shifting the + generators is realized in the case of an arbitrary number of groups. + A call to DirectProduct(G1, G2, ..., Gn) is generally expected to be faster + than a call to G1*G2*...*Gn (and thus the need for this algorithm). + + Examples + ======== + + >>> from sympy.combinatorics.group_constructs import DirectProduct + >>> from sympy.combinatorics.named_groups import CyclicGroup + >>> C = CyclicGroup(4) + >>> G = DirectProduct(C, C, C) + >>> G.order() + 64 + + See Also + ======== + + sympy.combinatorics.perm_groups.PermutationGroup.__mul__ + + """ + degrees = [] + gens_count = [] + total_degree = 0 + total_gens = 0 + for group in groups: + current_deg = group.degree + current_num_gens = len(group.generators) + degrees.append(current_deg) + total_degree += current_deg + gens_count.append(current_num_gens) + total_gens += current_num_gens + array_gens = [] + for i in range(total_gens): + array_gens.append(list(range(total_degree))) + current_gen = 0 + current_deg = 0 + for i in range(len(gens_count)): + for j in range(current_gen, current_gen + gens_count[i]): + gen = ((groups[i].generators)[j - current_gen]).array_form + array_gens[j][current_deg:current_deg + degrees[i]] = \ + [x + current_deg for x in gen] + current_gen += gens_count[i] + current_deg += degrees[i] + perm_gens = list(uniq([_af_new(list(a)) for a in array_gens])) + return PermutationGroup(perm_gens, dups=False) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/group_numbers.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/group_numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..8099dfc2ed8a6943aa1261f9374ed84f0ad3c522 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/group_numbers.py @@ -0,0 +1,294 @@ +from itertools import chain, combinations + +from sympy.external.gmpy import gcd +from sympy.ntheory.factor_ import factorint +from sympy.utilities.misc import as_int + + +def _is_nilpotent_number(factors: dict) -> bool: + """ Check whether `n` is a nilpotent number. + Note that ``factors`` is a prime factorization of `n`. + + This is a low-level helper for ``is_nilpotent_number``, for internal use. + """ + for p in factors.keys(): + for q, e in factors.items(): + # We want to calculate + # any(pow(q, k, p) == 1 for k in range(1, e + 1)) + m = 1 + for _ in range(e): + m = m*q % p + if m == 1: + return False + return True + + +def is_nilpotent_number(n) -> bool: + """ + Check whether `n` is a nilpotent number. A number `n` is said to be + nilpotent if and only if every finite group of order `n` is nilpotent. + For more information see [1]_. + + Examples + ======== + + >>> from sympy.combinatorics.group_numbers import is_nilpotent_number + >>> from sympy import randprime + >>> is_nilpotent_number(21) + False + >>> is_nilpotent_number(randprime(1, 30)**12) + True + + References + ========== + + .. [1] Pakianathan, J., Shankar, K., Nilpotent Numbers, + The American Mathematical Monthly, 107(7), 631-634. + .. [2] https://oeis.org/A056867 + + """ + n = as_int(n) + if n <= 0: + raise ValueError("n must be a positive integer, not %i" % n) + return _is_nilpotent_number(factorint(n)) + + +def is_abelian_number(n) -> bool: + """ + Check whether `n` is an abelian number. A number `n` is said to be abelian + if and only if every finite group of order `n` is abelian. For more + information see [1]_. + + Examples + ======== + + >>> from sympy.combinatorics.group_numbers import is_abelian_number + >>> from sympy import randprime + >>> is_abelian_number(4) + True + >>> is_abelian_number(randprime(1, 2000)**2) + True + >>> is_abelian_number(60) + False + + References + ========== + + .. [1] Pakianathan, J., Shankar, K., Nilpotent Numbers, + The American Mathematical Monthly, 107(7), 631-634. + .. [2] https://oeis.org/A051532 + + """ + n = as_int(n) + if n <= 0: + raise ValueError("n must be a positive integer, not %i" % n) + factors = factorint(n) + return all(e < 3 for e in factors.values()) and _is_nilpotent_number(factors) + + +def is_cyclic_number(n) -> bool: + """ + Check whether `n` is a cyclic number. A number `n` is said to be cyclic + if and only if every finite group of order `n` is cyclic. For more + information see [1]_. + + Examples + ======== + + >>> from sympy.combinatorics.group_numbers import is_cyclic_number + >>> from sympy import randprime + >>> is_cyclic_number(15) + True + >>> is_cyclic_number(randprime(1, 2000)**2) + False + >>> is_cyclic_number(4) + False + + References + ========== + + .. [1] Pakianathan, J., Shankar, K., Nilpotent Numbers, + The American Mathematical Monthly, 107(7), 631-634. + .. [2] https://oeis.org/A003277 + + """ + n = as_int(n) + if n <= 0: + raise ValueError("n must be a positive integer, not %i" % n) + factors = factorint(n) + return all(e == 1 for e in factors.values()) and _is_nilpotent_number(factors) + + +def _holder_formula(prime_factors): + r""" Number of groups of order `n`. + where `n` is squarefree and its prime factors are ``prime_factors``. + i.e., ``n == math.prod(prime_factors)`` + + Explanation + =========== + + When `n` is squarefree, the number of groups of order `n` is expressed by + + .. math :: + \sum_{d \mid n} \prod_p \frac{p^{c(p, d)} - 1}{p - 1} + + where `n=de`, `p` is the prime factor of `e`, + and `c(p, d)` is the number of prime factors `q` of `d` such that `q \equiv 1 \pmod{p}` [2]_. + + The formula is elegant, but can be improved when implemented as an algorithm. + Since `n` is assumed to be squarefree, the divisor `d` of `n` can be identified with the power set of prime factors. + We let `N` be the set of prime factors of `n`. + `F = \{p \in N : \forall q \in N, q \not\equiv 1 \pmod{p} \}, M = N \setminus F`, we have the following. + + .. math :: + \sum_{d \in 2^{M}} \prod_{p \in M \setminus d} \frac{p^{c(p, F \cup d)} - 1}{p - 1} + + Practically, many prime factors are expected to be members of `F`, thus reducing computation time. + + Parameters + ========== + + prime_factors : set + The set of prime factors of ``n``. where `n` is squarefree. + + Returns + ======= + + int : Number of groups of order ``n`` + + Examples + ======== + + >>> from sympy.combinatorics.group_numbers import _holder_formula + >>> _holder_formula({2}) # n = 2 + 1 + >>> _holder_formula({2, 3}) # n = 2*3 = 6 + 2 + + See Also + ======== + + groups_count + + References + ========== + + .. [1] Otto Holder, Die Gruppen der Ordnungen p^3, pq^2, pqr, p^4, + Math. Ann. 43 pp. 301-412 (1893). + http://dx.doi.org/10.1007/BF01443651 + .. [2] John H. Conway, Heiko Dietrich and E.A. O'Brien, + Counting groups: gnus, moas and other exotica + The Mathematical Intelligencer 30, 6-15 (2008) + https://doi.org/10.1007/BF02985731 + + """ + F = {p for p in prime_factors if all(q % p != 1 for q in prime_factors)} + M = prime_factors - F + + s = 0 + powerset = chain.from_iterable(combinations(M, r) for r in range(len(M)+1)) + for ps in powerset: + ps = set(ps) + prod = 1 + for p in M - ps: + c = len([q for q in F | ps if q % p == 1]) + prod *= (p**c - 1) // (p - 1) + if not prod: + break + s += prod + return s + + +def groups_count(n): + r""" Number of groups of order `n`. + In [1]_, ``gnu(n)`` is given, so we follow this notation here as well. + + Parameters + ========== + + n : Integer + ``n`` is a positive integer + + Returns + ======= + + int : ``gnu(n)`` + + Raises + ====== + + ValueError + Number of groups of order ``n`` is unknown or not implemented. + For example, gnu(`2^{11}`) is not yet known. + On the other hand, gnu(99) is known to be 2, + but this has not yet been implemented in this function. + + Examples + ======== + + >>> from sympy.combinatorics.group_numbers import groups_count + >>> groups_count(3) # There is only one cyclic group of order 3 + 1 + >>> # There are two groups of order 10: the cyclic group and the dihedral group + >>> groups_count(10) + 2 + + See Also + ======== + + is_cyclic_number + `n` is cyclic iff gnu(n) = 1 + + References + ========== + + .. [1] John H. Conway, Heiko Dietrich and E.A. O'Brien, + Counting groups: gnus, moas and other exotica + The Mathematical Intelligencer 30, 6-15 (2008) + https://doi.org/10.1007/BF02985731 + .. [2] https://oeis.org/A000001 + + """ + n = as_int(n) + if n <= 0: + raise ValueError("n must be a positive integer, not %i" % n) + factors = factorint(n) + if len(factors) == 1: + (p, e) = list(factors.items())[0] + if p == 2: + A000679 = [1, 1, 2, 5, 14, 51, 267, 2328, 56092, 10494213, 49487367289] + if e < len(A000679): + return A000679[e] + if p == 3: + A090091 = [1, 1, 2, 5, 15, 67, 504, 9310, 1396077, 5937876645] + if e < len(A090091): + return A090091[e] + if e <= 2: # gnu(p) = 1, gnu(p**2) = 2 + return e + if e == 3: # gnu(p**3) = 5 + return 5 + if e == 4: # if p is an odd prime, gnu(p**4) = 15 + return 15 + if e == 5: # if p >= 5, gnu(p**5) is expressed by the following equation + return 61 + 2*p + 2*gcd(p-1, 3) + gcd(p-1, 4) + if e == 6: # if p >= 6, gnu(p**6) is expressed by the following equation + return 3*p**2 + 39*p + 344 +\ + 24*gcd(p-1, 3) + 11*gcd(p-1, 4) + 2*gcd(p-1, 5) + if e == 7: # if p >= 7, gnu(p**7) is expressed by the following equation + if p == 5: + return 34297 + return 3*p**5 + 12*p**4 + 44*p**3 + 170*p**2 + 707*p + 2455 +\ + (4*p**2 + 44*p + 291)*gcd(p-1, 3) + (p**2 + 19*p + 135)*gcd(p-1, 4) + \ + (3*p + 31)*gcd(p-1, 5) + 4*gcd(p-1, 7) + 5*gcd(p-1, 8) + gcd(p-1, 9) + if any(e > 1 for e in factors.values()): # n is not squarefree + # some known values for small n that have more than 1 factor and are not square free (https://oeis.org/A000001) + small = {12: 5, 18: 5, 20: 5, 24: 15, 28: 4, 36: 14, 40: 14, 44: 4, 45: 2, 48: 52, + 50: 5, 52: 5, 54: 15, 56: 13, 60: 13, 63: 4, 68: 5, 72: 50, 75: 3, 76: 4, + 80: 52, 84: 15, 88: 12, 90: 10, 92: 4} + if n in small: + return small[n] + raise ValueError("Number of groups of order n is unknown or not implemented") + if len(factors) == 2: # n is squarefree semiprime + p, q = sorted(factors.keys()) + return 2 if q % p == 1 else 1 + return _holder_formula(set(factors.keys())) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/homomorphisms.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/homomorphisms.py new file mode 100644 index 0000000000000000000000000000000000000000..256f56b3aa7c5404d332f42e37d4f1117ea81db7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/homomorphisms.py @@ -0,0 +1,549 @@ +import itertools +from sympy.combinatorics.fp_groups import FpGroup, FpSubgroup, simplify_presentation +from sympy.combinatorics.free_groups import FreeGroup +from sympy.combinatorics.perm_groups import PermutationGroup +from sympy.core.intfunc import igcd +from sympy.functions.combinatorial.numbers import totient +from sympy.core.singleton import S + +class GroupHomomorphism: + ''' + A class representing group homomorphisms. Instantiate using `homomorphism()`. + + References + ========== + + .. [1] Holt, D., Eick, B. and O'Brien, E. (2005). Handbook of computational group theory. + + ''' + + def __init__(self, domain, codomain, images): + self.domain = domain + self.codomain = codomain + self.images = images + self._inverses = None + self._kernel = None + self._image = None + + def _invs(self): + ''' + Return a dictionary with `{gen: inverse}` where `gen` is a rewriting + generator of `codomain` (e.g. strong generator for permutation groups) + and `inverse` is an element of its preimage + + ''' + image = self.image() + inverses = {} + for k in list(self.images.keys()): + v = self.images[k] + if not (v in inverses + or v.is_identity): + inverses[v] = k + if isinstance(self.codomain, PermutationGroup): + gens = image.strong_gens + else: + gens = image.generators + for g in gens: + if g in inverses or g.is_identity: + continue + w = self.domain.identity + if isinstance(self.codomain, PermutationGroup): + parts = image._strong_gens_slp[g][::-1] + else: + parts = g + for s in parts: + if s in inverses: + w = w*inverses[s] + else: + w = w*inverses[s**-1]**-1 + inverses[g] = w + + return inverses + + def invert(self, g): + ''' + Return an element of the preimage of ``g`` or of each element + of ``g`` if ``g`` is a list. + + Explanation + =========== + + If the codomain is an FpGroup, the inverse for equal + elements might not always be the same unless the FpGroup's + rewriting system is confluent. However, making a system + confluent can be time-consuming. If it's important, try + `self.codomain.make_confluent()` first. + + ''' + from sympy.combinatorics import Permutation + from sympy.combinatorics.free_groups import FreeGroupElement + if isinstance(g, (Permutation, FreeGroupElement)): + if isinstance(self.codomain, FpGroup): + g = self.codomain.reduce(g) + if self._inverses is None: + self._inverses = self._invs() + image = self.image() + w = self.domain.identity + if isinstance(self.codomain, PermutationGroup): + gens = image.generator_product(g)[::-1] + else: + gens = g + # the following can't be "for s in gens:" + # because that would be equivalent to + # "for s in gens.array_form:" when g is + # a FreeGroupElement. On the other hand, + # when you call gens by index, the generator + # (or inverse) at position i is returned. + for i in range(len(gens)): + s = gens[i] + if s.is_identity: + continue + if s in self._inverses: + w = w*self._inverses[s] + else: + w = w*self._inverses[s**-1]**-1 + return w + elif isinstance(g, list): + return [self.invert(e) for e in g] + + def kernel(self): + ''' + Compute the kernel of `self`. + + ''' + if self._kernel is None: + self._kernel = self._compute_kernel() + return self._kernel + + def _compute_kernel(self): + G = self.domain + G_order = G.order() + if G_order is S.Infinity: + raise NotImplementedError( + "Kernel computation is not implemented for infinite groups") + gens = [] + if isinstance(G, PermutationGroup): + K = PermutationGroup(G.identity) + else: + K = FpSubgroup(G, gens, normal=True) + i = self.image().order() + while K.order()*i != G_order: + r = G.random() + k = r*self.invert(self(r))**-1 + if k not in K: + gens.append(k) + if isinstance(G, PermutationGroup): + K = PermutationGroup(gens) + else: + K = FpSubgroup(G, gens, normal=True) + return K + + def image(self): + ''' + Compute the image of `self`. + + ''' + if self._image is None: + values = list(set(self.images.values())) + if isinstance(self.codomain, PermutationGroup): + self._image = self.codomain.subgroup(values) + else: + self._image = FpSubgroup(self.codomain, values) + return self._image + + def _apply(self, elem): + ''' + Apply `self` to `elem`. + + ''' + if elem not in self.domain: + if isinstance(elem, (list, tuple)): + return [self._apply(e) for e in elem] + raise ValueError("The supplied element does not belong to the domain") + if elem.is_identity: + return self.codomain.identity + else: + images = self.images + value = self.codomain.identity + if isinstance(self.domain, PermutationGroup): + gens = self.domain.generator_product(elem, original=True) + for g in gens: + if g in self.images: + value = images[g]*value + else: + value = images[g**-1]**-1*value + else: + i = 0 + for _, p in elem.array_form: + if p < 0: + g = elem[i]**-1 + else: + g = elem[i] + value = value*images[g]**p + i += abs(p) + return value + + def __call__(self, elem): + return self._apply(elem) + + def is_injective(self): + ''' + Check if the homomorphism is injective + + ''' + return self.kernel().order() == 1 + + def is_surjective(self): + ''' + Check if the homomorphism is surjective + + ''' + im = self.image().order() + oth = self.codomain.order() + if im is S.Infinity and oth is S.Infinity: + return None + else: + return im == oth + + def is_isomorphism(self): + ''' + Check if `self` is an isomorphism. + + ''' + return self.is_injective() and self.is_surjective() + + def is_trivial(self): + ''' + Check is `self` is a trivial homomorphism, i.e. all elements + are mapped to the identity. + + ''' + return self.image().order() == 1 + + def compose(self, other): + ''' + Return the composition of `self` and `other`, i.e. + the homomorphism phi such that for all g in the domain + of `other`, phi(g) = self(other(g)) + + ''' + if not other.image().is_subgroup(self.domain): + raise ValueError("The image of `other` must be a subgroup of " + "the domain of `self`") + images = {g: self(other(g)) for g in other.images} + return GroupHomomorphism(other.domain, self.codomain, images) + + def restrict_to(self, H): + ''' + Return the restriction of the homomorphism to the subgroup `H` + of the domain. + + ''' + if not isinstance(H, PermutationGroup) or not H.is_subgroup(self.domain): + raise ValueError("Given H is not a subgroup of the domain") + domain = H + images = {g: self(g) for g in H.generators} + return GroupHomomorphism(domain, self.codomain, images) + + def invert_subgroup(self, H): + ''' + Return the subgroup of the domain that is the inverse image + of the subgroup ``H`` of the homomorphism image + + ''' + if not H.is_subgroup(self.image()): + raise ValueError("Given H is not a subgroup of the image") + gens = [] + P = PermutationGroup(self.image().identity) + for h in H.generators: + h_i = self.invert(h) + if h_i not in P: + gens.append(h_i) + P = PermutationGroup(gens) + for k in self.kernel().generators: + if k*h_i not in P: + gens.append(k*h_i) + P = PermutationGroup(gens) + return P + +def homomorphism(domain, codomain, gens, images=(), check=True): + ''' + Create (if possible) a group homomorphism from the group ``domain`` + to the group ``codomain`` defined by the images of the domain's + generators ``gens``. ``gens`` and ``images`` can be either lists or tuples + of equal sizes. If ``gens`` is a proper subset of the group's generators, + the unspecified generators will be mapped to the identity. If the + images are not specified, a trivial homomorphism will be created. + + If the given images of the generators do not define a homomorphism, + an exception is raised. + + If ``check`` is ``False``, do not check whether the given images actually + define a homomorphism. + + ''' + if not isinstance(domain, (PermutationGroup, FpGroup, FreeGroup)): + raise TypeError("The domain must be a group") + if not isinstance(codomain, (PermutationGroup, FpGroup, FreeGroup)): + raise TypeError("The codomain must be a group") + + generators = domain.generators + if not all(g in generators for g in gens): + raise ValueError("The supplied generators must be a subset of the domain's generators") + if not all(g in codomain for g in images): + raise ValueError("The images must be elements of the codomain") + + if images and len(images) != len(gens): + raise ValueError("The number of images must be equal to the number of generators") + + gens = list(gens) + images = list(images) + + images.extend([codomain.identity]*(len(generators)-len(images))) + gens.extend([g for g in generators if g not in gens]) + images = dict(zip(gens,images)) + + if check and not _check_homomorphism(domain, codomain, images): + raise ValueError("The given images do not define a homomorphism") + return GroupHomomorphism(domain, codomain, images) + +def _check_homomorphism(domain, codomain, images): + """ + Check that a given mapping of generators to images defines a homomorphism. + + Parameters + ========== + domain : PermutationGroup, FpGroup, FreeGroup + codomain : PermutationGroup, FpGroup, FreeGroup + images : dict + The set of keys must be equal to domain.generators. + The values must be elements of the codomain. + + """ + pres = domain if hasattr(domain, 'relators') else domain.presentation() + rels = pres.relators + gens = pres.generators + symbols = [g.ext_rep[0] for g in gens] + symbols_to_domain_generators = dict(zip(symbols, domain.generators)) + identity = codomain.identity + + def _image(r): + w = identity + for symbol, power in r.array_form: + g = symbols_to_domain_generators[symbol] + w *= images[g]**power + return w + + for r in rels: + if isinstance(codomain, FpGroup): + s = codomain.equals(_image(r), identity) + if s is None: + # only try to make the rewriting system + # confluent when it can't determine the + # truth of equality otherwise + success = codomain.make_confluent() + s = codomain.equals(_image(r), identity) + if s is None and not success: + raise RuntimeError("Can't determine if the images " + "define a homomorphism. Try increasing " + "the maximum number of rewriting rules " + "(group._rewriting_system.set_max(new_value); " + "the current value is stored in group._rewriting" + "_system.maxeqns)") + else: + s = _image(r).is_identity + if not s: + return False + return True + +def orbit_homomorphism(group, omega): + ''' + Return the homomorphism induced by the action of the permutation + group ``group`` on the set ``omega`` that is closed under the action. + + ''' + from sympy.combinatorics import Permutation + from sympy.combinatorics.named_groups import SymmetricGroup + codomain = SymmetricGroup(len(omega)) + identity = codomain.identity + omega = list(omega) + images = {g: identity*Permutation([omega.index(o^g) for o in omega]) for g in group.generators} + group._schreier_sims(base=omega) + H = GroupHomomorphism(group, codomain, images) + if len(group.basic_stabilizers) > len(omega): + H._kernel = group.basic_stabilizers[len(omega)] + else: + H._kernel = PermutationGroup([group.identity]) + return H + +def block_homomorphism(group, blocks): + ''' + Return the homomorphism induced by the action of the permutation + group ``group`` on the block system ``blocks``. The latter should be + of the same form as returned by the ``minimal_block`` method for + permutation groups, namely a list of length ``group.degree`` where + the i-th entry is a representative of the block i belongs to. + + ''' + from sympy.combinatorics import Permutation + from sympy.combinatorics.named_groups import SymmetricGroup + + n = len(blocks) + + # number the blocks; m is the total number, + # b is such that b[i] is the number of the block i belongs to, + # p is the list of length m such that p[i] is the representative + # of the i-th block + m = 0 + p = [] + b = [None]*n + for i in range(n): + if blocks[i] == i: + p.append(i) + b[i] = m + m += 1 + for i in range(n): + b[i] = b[blocks[i]] + + codomain = SymmetricGroup(m) + # the list corresponding to the identity permutation in codomain + identity = range(m) + images = {g: Permutation([b[p[i]^g] for i in identity]) for g in group.generators} + H = GroupHomomorphism(group, codomain, images) + return H + +def group_isomorphism(G, H, isomorphism=True): + ''' + Compute an isomorphism between 2 given groups. + + Parameters + ========== + + G : A finite ``FpGroup`` or a ``PermutationGroup``. + First group. + + H : A finite ``FpGroup`` or a ``PermutationGroup`` + Second group. + + isomorphism : bool + This is used to avoid the computation of homomorphism + when the user only wants to check if there exists + an isomorphism between the groups. + + Returns + ======= + + If isomorphism = False -- Returns a boolean. + If isomorphism = True -- Returns a boolean and an isomorphism between `G` and `H`. + + Examples + ======== + + >>> from sympy.combinatorics import free_group, Permutation + >>> from sympy.combinatorics.perm_groups import PermutationGroup + >>> from sympy.combinatorics.fp_groups import FpGroup + >>> from sympy.combinatorics.homomorphisms import group_isomorphism + >>> from sympy.combinatorics.named_groups import DihedralGroup, AlternatingGroup + + >>> D = DihedralGroup(8) + >>> p = Permutation(0, 1, 2, 3, 4, 5, 6, 7) + >>> P = PermutationGroup(p) + >>> group_isomorphism(D, P) + (False, None) + + >>> F, a, b = free_group("a, b") + >>> G = FpGroup(F, [a**3, b**3, (a*b)**2]) + >>> H = AlternatingGroup(4) + >>> (check, T) = group_isomorphism(G, H) + >>> check + True + >>> T(b*a*b**-1*a**-1*b**-1) + (0 2 3) + + Notes + ===== + + Uses the approach suggested by Robert Tarjan to compute the isomorphism between two groups. + First, the generators of ``G`` are mapped to the elements of ``H`` and + we check if the mapping induces an isomorphism. + + ''' + if not isinstance(G, (PermutationGroup, FpGroup)): + raise TypeError("The group must be a PermutationGroup or an FpGroup") + if not isinstance(H, (PermutationGroup, FpGroup)): + raise TypeError("The group must be a PermutationGroup or an FpGroup") + + if isinstance(G, FpGroup) and isinstance(H, FpGroup): + G = simplify_presentation(G) + H = simplify_presentation(H) + # Two infinite FpGroups with the same generators are isomorphic + # when the relators are same but are ordered differently. + if G.generators == H.generators and (G.relators).sort() == (H.relators).sort(): + if not isomorphism: + return True + return (True, homomorphism(G, H, G.generators, H.generators)) + + # `_H` is the permutation group isomorphic to `H`. + _H = H + g_order = G.order() + h_order = H.order() + + if g_order is S.Infinity: + raise NotImplementedError("Isomorphism methods are not implemented for infinite groups.") + + if isinstance(H, FpGroup): + if h_order is S.Infinity: + raise NotImplementedError("Isomorphism methods are not implemented for infinite groups.") + _H, h_isomorphism = H._to_perm_group() + + if (g_order != h_order) or (G.is_abelian != H.is_abelian): + if not isomorphism: + return False + return (False, None) + + if not isomorphism: + # Two groups of the same cyclic numbered order + # are isomorphic to each other. + n = g_order + if (igcd(n, totient(n))) == 1: + return True + + # Match the generators of `G` with subsets of `_H` + gens = list(G.generators) + for subset in itertools.permutations(_H, len(gens)): + images = list(subset) + images.extend([_H.identity]*(len(G.generators)-len(images))) + _images = dict(zip(gens,images)) + if _check_homomorphism(G, _H, _images): + if isinstance(H, FpGroup): + images = h_isomorphism.invert(images) + T = homomorphism(G, H, G.generators, images, check=False) + if T.is_isomorphism(): + # It is a valid isomorphism + if not isomorphism: + return True + return (True, T) + + if not isomorphism: + return False + return (False, None) + +def is_isomorphic(G, H): + ''' + Check if the groups are isomorphic to each other + + Parameters + ========== + + G : A finite ``FpGroup`` or a ``PermutationGroup`` + First group. + + H : A finite ``FpGroup`` or a ``PermutationGroup`` + Second group. + + Returns + ======= + + boolean + ''' + return group_isomorphism(G, H, isomorphism=False) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/named_groups.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/named_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..59f10c40ef716e3b644e00f936323e9f6936eb88 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/named_groups.py @@ -0,0 +1,332 @@ +from sympy.combinatorics.group_constructs import DirectProduct +from sympy.combinatorics.perm_groups import PermutationGroup +from sympy.combinatorics.permutations import Permutation + +_af_new = Permutation._af_new + + +def AbelianGroup(*cyclic_orders): + """ + Returns the direct product of cyclic groups with the given orders. + + Explanation + =========== + + According to the structure theorem for finite abelian groups ([1]), + every finite abelian group can be written as the direct product of + finitely many cyclic groups. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import AbelianGroup + >>> AbelianGroup(3, 4) + PermutationGroup([ + (6)(0 1 2), + (3 4 5 6)]) + >>> _.is_group + True + + See Also + ======== + + DirectProduct + + References + ========== + + .. [1] https://groupprops.subwiki.org/wiki/Structure_theorem_for_finitely_generated_abelian_groups + + """ + groups = [] + degree = 0 + order = 1 + for size in cyclic_orders: + degree += size + order *= size + groups.append(CyclicGroup(size)) + G = DirectProduct(*groups) + G._is_abelian = True + G._degree = degree + G._order = order + + return G + + +def AlternatingGroup(n): + """ + Generates the alternating group on ``n`` elements as a permutation group. + + Explanation + =========== + + For ``n > 2``, the generators taken are ``(0 1 2), (0 1 2 ... n-1)`` for + ``n`` odd + and ``(0 1 2), (1 2 ... n-1)`` for ``n`` even (See [1], p.31, ex.6.9.). + After the group is generated, some of its basic properties are set. + The cases ``n = 1, 2`` are handled separately. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import AlternatingGroup + >>> G = AlternatingGroup(4) + >>> G.is_group + True + >>> a = list(G.generate_dimino()) + >>> len(a) + 12 + >>> all(perm.is_even for perm in a) + True + + See Also + ======== + + SymmetricGroup, CyclicGroup, DihedralGroup + + References + ========== + + .. [1] Armstrong, M. "Groups and Symmetry" + + """ + # small cases are special + if n in (1, 2): + return PermutationGroup([Permutation([0])]) + + a = list(range(n)) + a[0], a[1], a[2] = a[1], a[2], a[0] + gen1 = a + if n % 2: + a = list(range(1, n)) + a.append(0) + gen2 = a + else: + a = list(range(2, n)) + a.append(1) + a.insert(0, 0) + gen2 = a + gens = [gen1, gen2] + if gen1 == gen2: + gens = gens[:1] + G = PermutationGroup([_af_new(a) for a in gens], dups=False) + + set_alternating_group_properties(G, n, n) + G._is_alt = True + return G + + +def set_alternating_group_properties(G, n, degree): + """Set known properties of an alternating group. """ + if n < 4: + G._is_abelian = True + G._is_nilpotent = True + else: + G._is_abelian = False + G._is_nilpotent = False + if n < 5: + G._is_solvable = True + else: + G._is_solvable = False + G._degree = degree + G._is_transitive = True + G._is_dihedral = False + + +def CyclicGroup(n): + """ + Generates the cyclic group of order ``n`` as a permutation group. + + Explanation + =========== + + The generator taken is the ``n``-cycle ``(0 1 2 ... n-1)`` + (in cycle notation). After the group is generated, some of its basic + properties are set. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import CyclicGroup + >>> G = CyclicGroup(6) + >>> G.is_group + True + >>> G.order() + 6 + >>> list(G.generate_schreier_sims(af=True)) + [[0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 0], [2, 3, 4, 5, 0, 1], + [3, 4, 5, 0, 1, 2], [4, 5, 0, 1, 2, 3], [5, 0, 1, 2, 3, 4]] + + See Also + ======== + + SymmetricGroup, DihedralGroup, AlternatingGroup + + """ + a = list(range(1, n)) + a.append(0) + gen = _af_new(a) + G = PermutationGroup([gen]) + + G._is_abelian = True + G._is_nilpotent = True + G._is_solvable = True + G._degree = n + G._is_transitive = True + G._order = n + G._is_dihedral = (n == 2) + return G + + +def DihedralGroup(n): + r""" + Generates the dihedral group `D_n` as a permutation group. + + Explanation + =========== + + The dihedral group `D_n` is the group of symmetries of the regular + ``n``-gon. The generators taken are the ``n``-cycle ``a = (0 1 2 ... n-1)`` + (a rotation of the ``n``-gon) and ``b = (0 n-1)(1 n-2)...`` + (a reflection of the ``n``-gon) in cycle rotation. It is easy to see that + these satisfy ``a**n = b**2 = 1`` and ``bab = ~a`` so they indeed generate + `D_n` (See [1]). After the group is generated, some of its basic properties + are set. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import DihedralGroup + >>> G = DihedralGroup(5) + >>> G.is_group + True + >>> a = list(G.generate_dimino()) + >>> [perm.cyclic_form for perm in a] + [[], [[0, 1, 2, 3, 4]], [[0, 2, 4, 1, 3]], + [[0, 3, 1, 4, 2]], [[0, 4, 3, 2, 1]], [[0, 4], [1, 3]], + [[1, 4], [2, 3]], [[0, 1], [2, 4]], [[0, 2], [3, 4]], + [[0, 3], [1, 2]]] + + See Also + ======== + + SymmetricGroup, CyclicGroup, AlternatingGroup + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Dihedral_group + + """ + # small cases are special + if n == 1: + return PermutationGroup([Permutation([1, 0])]) + if n == 2: + return PermutationGroup([Permutation([1, 0, 3, 2]), + Permutation([2, 3, 0, 1]), Permutation([3, 2, 1, 0])]) + + a = list(range(1, n)) + a.append(0) + gen1 = _af_new(a) + a = list(range(n)) + a.reverse() + gen2 = _af_new(a) + G = PermutationGroup([gen1, gen2]) + # if n is a power of 2, group is nilpotent + if n & (n-1) == 0: + G._is_nilpotent = True + else: + G._is_nilpotent = False + G._is_dihedral = True + G._is_abelian = False + G._is_solvable = True + G._degree = n + G._is_transitive = True + G._order = 2*n + return G + + +def SymmetricGroup(n): + """ + Generates the symmetric group on ``n`` elements as a permutation group. + + Explanation + =========== + + The generators taken are the ``n``-cycle + ``(0 1 2 ... n-1)`` and the transposition ``(0 1)`` (in cycle notation). + (See [1]). After the group is generated, some of its basic properties + are set. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> G = SymmetricGroup(4) + >>> G.is_group + True + >>> G.order() + 24 + >>> list(G.generate_schreier_sims(af=True)) + [[0, 1, 2, 3], [1, 2, 3, 0], [2, 3, 0, 1], [3, 1, 2, 0], [0, 2, 3, 1], + [1, 3, 0, 2], [2, 0, 1, 3], [3, 2, 0, 1], [0, 3, 1, 2], [1, 0, 2, 3], + [2, 1, 3, 0], [3, 0, 1, 2], [0, 1, 3, 2], [1, 2, 0, 3], [2, 3, 1, 0], + [3, 1, 0, 2], [0, 2, 1, 3], [1, 3, 2, 0], [2, 0, 3, 1], [3, 2, 1, 0], + [0, 3, 2, 1], [1, 0, 3, 2], [2, 1, 0, 3], [3, 0, 2, 1]] + + See Also + ======== + + CyclicGroup, DihedralGroup, AlternatingGroup + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Symmetric_group#Generators_and_relations + + """ + if n == 1: + G = PermutationGroup([Permutation([0])]) + elif n == 2: + G = PermutationGroup([Permutation([1, 0])]) + else: + a = list(range(1, n)) + a.append(0) + gen1 = _af_new(a) + a = list(range(n)) + a[0], a[1] = a[1], a[0] + gen2 = _af_new(a) + G = PermutationGroup([gen1, gen2]) + set_symmetric_group_properties(G, n, n) + G._is_sym = True + return G + + +def set_symmetric_group_properties(G, n, degree): + """Set known properties of a symmetric group. """ + if n < 3: + G._is_abelian = True + G._is_nilpotent = True + else: + G._is_abelian = False + G._is_nilpotent = False + if n < 5: + G._is_solvable = True + else: + G._is_solvable = False + G._degree = degree + G._is_transitive = True + G._is_dihedral = (n in [2, 3]) # cf Landau's func and Stirling's approx + + +def RubikGroup(n): + """Return a group of Rubik's cube generators + + >>> from sympy.combinatorics.named_groups import RubikGroup + >>> RubikGroup(2).is_group + True + """ + from sympy.combinatorics.generators import rubik + if n <= 1: + raise ValueError("Invalid cube. n has to be greater than 1") + return PermutationGroup(rubik(n)) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/partitions.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/partitions.py new file mode 100644 index 0000000000000000000000000000000000000000..dfe646baabbb5bf2350cba859a265ac32bbfaf53 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/partitions.py @@ -0,0 +1,745 @@ +from sympy.core import Basic, Dict, sympify, Tuple +from sympy.core.numbers import Integer +from sympy.core.sorting import default_sort_key +from sympy.core.sympify import _sympify +from sympy.functions.combinatorial.numbers import bell +from sympy.matrices import zeros +from sympy.sets.sets import FiniteSet, Union +from sympy.utilities.iterables import flatten, group +from sympy.utilities.misc import as_int + + +from collections import defaultdict + + +class Partition(FiniteSet): + """ + This class represents an abstract partition. + + A partition is a set of disjoint sets whose union equals a given set. + + See Also + ======== + + sympy.utilities.iterables.partitions, + sympy.utilities.iterables.multiset_partitions + """ + + _rank = None + _partition = None + + def __new__(cls, *partition): + """ + Generates a new partition object. + + This method also verifies if the arguments passed are + valid and raises a ValueError if they are not. + + Examples + ======== + + Creating Partition from Python lists: + + >>> from sympy.combinatorics import Partition + >>> a = Partition([1, 2], [3]) + >>> a + Partition({3}, {1, 2}) + >>> a.partition + [[1, 2], [3]] + >>> len(a) + 2 + >>> a.members + (1, 2, 3) + + Creating Partition from Python sets: + + >>> Partition({1, 2, 3}, {4, 5}) + Partition({4, 5}, {1, 2, 3}) + + Creating Partition from SymPy finite sets: + + >>> from sympy import FiniteSet + >>> a = FiniteSet(1, 2, 3) + >>> b = FiniteSet(4, 5) + >>> Partition(a, b) + Partition({4, 5}, {1, 2, 3}) + """ + args = [] + dups = False + for arg in partition: + if isinstance(arg, list): + as_set = set(arg) + if len(as_set) < len(arg): + dups = True + break # error below + arg = as_set + args.append(_sympify(arg)) + + if not all(isinstance(part, FiniteSet) for part in args): + raise ValueError( + "Each argument to Partition should be " \ + "a list, set, or a FiniteSet") + + # sort so we have a canonical reference for RGS + U = Union(*args) + if dups or len(U) < sum(len(arg) for arg in args): + raise ValueError("Partition contained duplicate elements.") + + obj = FiniteSet.__new__(cls, *args) + obj.members = tuple(U) + obj.size = len(U) + return obj + + def sort_key(self, order=None): + """Return a canonical key that can be used for sorting. + + Ordering is based on the size and sorted elements of the partition + and ties are broken with the rank. + + Examples + ======== + + >>> from sympy import default_sort_key + >>> from sympy.combinatorics import Partition + >>> from sympy.abc import x + >>> a = Partition([1, 2]) + >>> b = Partition([3, 4]) + >>> c = Partition([1, x]) + >>> d = Partition(list(range(4))) + >>> l = [d, b, a + 1, a, c] + >>> l.sort(key=default_sort_key); l + [Partition({1, 2}), Partition({1}, {2}), Partition({1, x}), Partition({3, 4}), Partition({0, 1, 2, 3})] + """ + if order is None: + members = self.members + else: + members = tuple(sorted(self.members, + key=lambda w: default_sort_key(w, order))) + return tuple(map(default_sort_key, (self.size, members, self.rank))) + + @property + def partition(self): + """Return partition as a sorted list of lists. + + Examples + ======== + + >>> from sympy.combinatorics import Partition + >>> Partition([1], [2, 3]).partition + [[1], [2, 3]] + """ + if self._partition is None: + self._partition = sorted([sorted(p, key=default_sort_key) + for p in self.args]) + return self._partition + + def __add__(self, other): + """ + Return permutation whose rank is ``other`` greater than current rank, + (mod the maximum rank for the set). + + Examples + ======== + + >>> from sympy.combinatorics import Partition + >>> a = Partition([1, 2], [3]) + >>> a.rank + 1 + >>> (a + 1).rank + 2 + >>> (a + 100).rank + 1 + """ + other = as_int(other) + offset = self.rank + other + result = RGS_unrank((offset) % + RGS_enum(self.size), + self.size) + return Partition.from_rgs(result, self.members) + + def __sub__(self, other): + """ + Return permutation whose rank is ``other`` less than current rank, + (mod the maximum rank for the set). + + Examples + ======== + + >>> from sympy.combinatorics import Partition + >>> a = Partition([1, 2], [3]) + >>> a.rank + 1 + >>> (a - 1).rank + 0 + >>> (a - 100).rank + 1 + """ + return self.__add__(-other) + + def __le__(self, other): + """ + Checks if a partition is less than or equal to + the other based on rank. + + Examples + ======== + + >>> from sympy.combinatorics import Partition + >>> a = Partition([1, 2], [3, 4, 5]) + >>> b = Partition([1], [2, 3], [4], [5]) + >>> a.rank, b.rank + (9, 34) + >>> a <= a + True + >>> a <= b + True + """ + return self.sort_key() <= sympify(other).sort_key() + + def __lt__(self, other): + """ + Checks if a partition is less than the other. + + Examples + ======== + + >>> from sympy.combinatorics import Partition + >>> a = Partition([1, 2], [3, 4, 5]) + >>> b = Partition([1], [2, 3], [4], [5]) + >>> a.rank, b.rank + (9, 34) + >>> a < b + True + """ + return self.sort_key() < sympify(other).sort_key() + + @property + def rank(self): + """ + Gets the rank of a partition. + + Examples + ======== + + >>> from sympy.combinatorics import Partition + >>> a = Partition([1, 2], [3], [4, 5]) + >>> a.rank + 13 + """ + if self._rank is not None: + return self._rank + self._rank = RGS_rank(self.RGS) + return self._rank + + @property + def RGS(self): + """ + Returns the "restricted growth string" of the partition. + + Explanation + =========== + + The RGS is returned as a list of indices, L, where L[i] indicates + the block in which element i appears. For example, in a partition + of 3 elements (a, b, c) into 2 blocks ([c], [a, b]) the RGS is + [1, 1, 0]: "a" is in block 1, "b" is in block 1 and "c" is in block 0. + + Examples + ======== + + >>> from sympy.combinatorics import Partition + >>> a = Partition([1, 2], [3], [4, 5]) + >>> a.members + (1, 2, 3, 4, 5) + >>> a.RGS + (0, 0, 1, 2, 2) + >>> a + 1 + Partition({3}, {4}, {5}, {1, 2}) + >>> _.RGS + (0, 0, 1, 2, 3) + """ + rgs = {} + partition = self.partition + for i, part in enumerate(partition): + for j in part: + rgs[j] = i + return tuple([rgs[i] for i in sorted( + [i for p in partition for i in p], key=default_sort_key)]) + + @classmethod + def from_rgs(self, rgs, elements): + """ + Creates a set partition from a restricted growth string. + + Explanation + =========== + + The indices given in rgs are assumed to be the index + of the element as given in elements *as provided* (the + elements are not sorted by this routine). Block numbering + starts from 0. If any block was not referenced in ``rgs`` + an error will be raised. + + Examples + ======== + + >>> from sympy.combinatorics import Partition + >>> Partition.from_rgs([0, 1, 2, 0, 1], list('abcde')) + Partition({c}, {a, d}, {b, e}) + >>> Partition.from_rgs([0, 1, 2, 0, 1], list('cbead')) + Partition({e}, {a, c}, {b, d}) + >>> a = Partition([1, 4], [2], [3, 5]) + >>> Partition.from_rgs(a.RGS, a.members) + Partition({2}, {1, 4}, {3, 5}) + """ + if len(rgs) != len(elements): + raise ValueError('mismatch in rgs and element lengths') + max_elem = max(rgs) + 1 + partition = [[] for i in range(max_elem)] + j = 0 + for i in rgs: + partition[i].append(elements[j]) + j += 1 + if not all(p for p in partition): + raise ValueError('some blocks of the partition were empty.') + return Partition(*partition) + + +class IntegerPartition(Basic): + """ + This class represents an integer partition. + + Explanation + =========== + + In number theory and combinatorics, a partition of a positive integer, + ``n``, also called an integer partition, is a way of writing ``n`` as a + list of positive integers that sum to n. Two partitions that differ only + in the order of summands are considered to be the same partition; if order + matters then the partitions are referred to as compositions. For example, + 4 has five partitions: [4], [3, 1], [2, 2], [2, 1, 1], and [1, 1, 1, 1]; + the compositions [1, 2, 1] and [1, 1, 2] are the same as partition + [2, 1, 1]. + + See Also + ======== + + sympy.utilities.iterables.partitions, + sympy.utilities.iterables.multiset_partitions + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Partition_%28number_theory%29 + """ + + _dict = None + _keys = None + + def __new__(cls, partition, integer=None): + """ + Generates a new IntegerPartition object from a list or dictionary. + + Explanation + =========== + + The partition can be given as a list of positive integers or a + dictionary of (integer, multiplicity) items. If the partition is + preceded by an integer an error will be raised if the partition + does not sum to that given integer. + + Examples + ======== + + >>> from sympy.combinatorics.partitions import IntegerPartition + >>> a = IntegerPartition([5, 4, 3, 1, 1]) + >>> a + IntegerPartition(14, (5, 4, 3, 1, 1)) + >>> print(a) + [5, 4, 3, 1, 1] + >>> IntegerPartition({1:3, 2:1}) + IntegerPartition(5, (2, 1, 1, 1)) + + If the value that the partition should sum to is given first, a check + will be made to see n error will be raised if there is a discrepancy: + + >>> IntegerPartition(10, [5, 4, 3, 1]) + Traceback (most recent call last): + ... + ValueError: The partition is not valid + + """ + if integer is not None: + integer, partition = partition, integer + if isinstance(partition, (dict, Dict)): + _ = [] + for k, v in sorted(partition.items(), reverse=True): + if not v: + continue + k, v = as_int(k), as_int(v) + _.extend([k]*v) + partition = tuple(_) + else: + partition = tuple(sorted(map(as_int, partition), reverse=True)) + sum_ok = False + if integer is None: + integer = sum(partition) + sum_ok = True + else: + integer = as_int(integer) + + if not sum_ok and sum(partition) != integer: + raise ValueError("Partition did not add to %s" % integer) + if any(i < 1 for i in partition): + raise ValueError("All integer summands must be greater than one") + + obj = Basic.__new__(cls, Integer(integer), Tuple(*partition)) + obj.partition = list(partition) + obj.integer = integer + return obj + + def prev_lex(self): + """Return the previous partition of the integer, n, in lexical order, + wrapping around to [1, ..., 1] if the partition is [n]. + + Examples + ======== + + >>> from sympy.combinatorics.partitions import IntegerPartition + >>> p = IntegerPartition([4]) + >>> print(p.prev_lex()) + [3, 1] + >>> p.partition > p.prev_lex().partition + True + """ + d = defaultdict(int) + d.update(self.as_dict()) + keys = self._keys + if keys == [1]: + return IntegerPartition({self.integer: 1}) + if keys[-1] != 1: + d[keys[-1]] -= 1 + if keys[-1] == 2: + d[1] = 2 + else: + d[keys[-1] - 1] = d[1] = 1 + else: + d[keys[-2]] -= 1 + left = d[1] + keys[-2] + new = keys[-2] + d[1] = 0 + while left: + new -= 1 + if left - new >= 0: + d[new] += left//new + left -= d[new]*new + return IntegerPartition(self.integer, d) + + def next_lex(self): + """Return the next partition of the integer, n, in lexical order, + wrapping around to [n] if the partition is [1, ..., 1]. + + Examples + ======== + + >>> from sympy.combinatorics.partitions import IntegerPartition + >>> p = IntegerPartition([3, 1]) + >>> print(p.next_lex()) + [4] + >>> p.partition < p.next_lex().partition + True + """ + d = defaultdict(int) + d.update(self.as_dict()) + key = self._keys + a = key[-1] + if a == self.integer: + d.clear() + d[1] = self.integer + elif a == 1: + if d[a] > 1: + d[a + 1] += 1 + d[a] -= 2 + else: + b = key[-2] + d[b + 1] += 1 + d[1] = (d[b] - 1)*b + d[b] = 0 + else: + if d[a] > 1: + if len(key) == 1: + d.clear() + d[a + 1] = 1 + d[1] = self.integer - a - 1 + else: + a1 = a + 1 + d[a1] += 1 + d[1] = d[a]*a - a1 + d[a] = 0 + else: + b = key[-2] + b1 = b + 1 + d[b1] += 1 + need = d[b]*b + d[a]*a - b1 + d[a] = d[b] = 0 + d[1] = need + return IntegerPartition(self.integer, d) + + def as_dict(self): + """Return the partition as a dictionary whose keys are the + partition integers and the values are the multiplicity of that + integer. + + Examples + ======== + + >>> from sympy.combinatorics.partitions import IntegerPartition + >>> IntegerPartition([1]*3 + [2] + [3]*4).as_dict() + {1: 3, 2: 1, 3: 4} + """ + if self._dict is None: + groups = group(self.partition, multiple=False) + self._keys = [g[0] for g in groups] + self._dict = dict(groups) + return self._dict + + @property + def conjugate(self): + """ + Computes the conjugate partition of itself. + + Examples + ======== + + >>> from sympy.combinatorics.partitions import IntegerPartition + >>> a = IntegerPartition([6, 3, 3, 2, 1]) + >>> a.conjugate + [5, 4, 3, 1, 1, 1] + """ + j = 1 + temp_arr = list(self.partition) + [0] + k = temp_arr[0] + b = [0]*k + while k > 0: + while k > temp_arr[j]: + b[k - 1] = j + k -= 1 + j += 1 + return b + + def __lt__(self, other): + """Return True if self is less than other when the partition + is listed from smallest to biggest. + + Examples + ======== + + >>> from sympy.combinatorics.partitions import IntegerPartition + >>> a = IntegerPartition([3, 1]) + >>> a < a + False + >>> b = a.next_lex() + >>> a < b + True + >>> a == b + False + """ + return list(reversed(self.partition)) < list(reversed(other.partition)) + + def __le__(self, other): + """Return True if self is less than other when the partition + is listed from smallest to biggest. + + Examples + ======== + + >>> from sympy.combinatorics.partitions import IntegerPartition + >>> a = IntegerPartition([4]) + >>> a <= a + True + """ + return list(reversed(self.partition)) <= list(reversed(other.partition)) + + def as_ferrers(self, char='#'): + """ + Prints the ferrer diagram of a partition. + + Examples + ======== + + >>> from sympy.combinatorics.partitions import IntegerPartition + >>> print(IntegerPartition([1, 1, 5]).as_ferrers()) + ##### + # + # + """ + return "\n".join([char*i for i in self.partition]) + + def __str__(self): + return str(list(self.partition)) + + +def random_integer_partition(n, seed=None): + """ + Generates a random integer partition summing to ``n`` as a list + of reverse-sorted integers. + + Examples + ======== + + >>> from sympy.combinatorics.partitions import random_integer_partition + + For the following, a seed is given so a known value can be shown; in + practice, the seed would not be given. + + >>> random_integer_partition(100, seed=[1, 1, 12, 1, 2, 1, 85, 1]) + [85, 12, 2, 1] + >>> random_integer_partition(10, seed=[1, 2, 3, 1, 5, 1]) + [5, 3, 1, 1] + >>> random_integer_partition(1) + [1] + """ + from sympy.core.random import _randint + + n = as_int(n) + if n < 1: + raise ValueError('n must be a positive integer') + + randint = _randint(seed) + + partition = [] + while (n > 0): + k = randint(1, n) + mult = randint(1, n//k) + partition.append((k, mult)) + n -= k*mult + partition.sort(reverse=True) + partition = flatten([[k]*m for k, m in partition]) + return partition + + +def RGS_generalized(m): + """ + Computes the m + 1 generalized unrestricted growth strings + and returns them as rows in matrix. + + Examples + ======== + + >>> from sympy.combinatorics.partitions import RGS_generalized + >>> RGS_generalized(6) + Matrix([ + [ 1, 1, 1, 1, 1, 1, 1], + [ 1, 2, 3, 4, 5, 6, 0], + [ 2, 5, 10, 17, 26, 0, 0], + [ 5, 15, 37, 77, 0, 0, 0], + [ 15, 52, 151, 0, 0, 0, 0], + [ 52, 203, 0, 0, 0, 0, 0], + [203, 0, 0, 0, 0, 0, 0]]) + """ + d = zeros(m + 1) + for i in range(m + 1): + d[0, i] = 1 + + for i in range(1, m + 1): + for j in range(m): + if j <= m - i: + d[i, j] = j * d[i - 1, j] + d[i - 1, j + 1] + else: + d[i, j] = 0 + return d + + +def RGS_enum(m): + """ + RGS_enum computes the total number of restricted growth strings + possible for a superset of size m. + + Examples + ======== + + >>> from sympy.combinatorics.partitions import RGS_enum + >>> from sympy.combinatorics import Partition + >>> RGS_enum(4) + 15 + >>> RGS_enum(5) + 52 + >>> RGS_enum(6) + 203 + + We can check that the enumeration is correct by actually generating + the partitions. Here, the 15 partitions of 4 items are generated: + + >>> a = Partition(list(range(4))) + >>> s = set() + >>> for i in range(20): + ... s.add(a) + ... a += 1 + ... + >>> assert len(s) == 15 + + """ + if (m < 1): + return 0 + elif (m == 1): + return 1 + else: + return bell(m) + + +def RGS_unrank(rank, m): + """ + Gives the unranked restricted growth string for a given + superset size. + + Examples + ======== + + >>> from sympy.combinatorics.partitions import RGS_unrank + >>> RGS_unrank(14, 4) + [0, 1, 2, 3] + >>> RGS_unrank(0, 4) + [0, 0, 0, 0] + """ + if m < 1: + raise ValueError("The superset size must be >= 1") + if rank < 0 or RGS_enum(m) <= rank: + raise ValueError("Invalid arguments") + + L = [1] * (m + 1) + j = 1 + D = RGS_generalized(m) + for i in range(2, m + 1): + v = D[m - i, j] + cr = j*v + if cr <= rank: + L[i] = j + 1 + rank -= cr + j += 1 + else: + L[i] = int(rank / v + 1) + rank %= v + return [x - 1 for x in L[1:]] + + +def RGS_rank(rgs): + """ + Computes the rank of a restricted growth string. + + Examples + ======== + + >>> from sympy.combinatorics.partitions import RGS_rank, RGS_unrank + >>> RGS_rank([0, 1, 2, 1, 3]) + 42 + >>> RGS_rank(RGS_unrank(4, 7)) + 4 + """ + rgs_size = len(rgs) + rank = 0 + D = RGS_generalized(rgs_size) + for i in range(1, rgs_size): + n = len(rgs[(i + 1):]) + m = max(rgs[0:i]) + rank += D[n, m + 1] * rgs[i] + return rank diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/pc_groups.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/pc_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..abf7b82258d8bb61ba350509fcbb3110a035fc88 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/pc_groups.py @@ -0,0 +1,710 @@ +from sympy.ntheory.primetest import isprime +from sympy.combinatorics.perm_groups import PermutationGroup +from sympy.printing.defaults import DefaultPrinting +from sympy.combinatorics.free_groups import free_group + + +class PolycyclicGroup(DefaultPrinting): + + is_group = True + is_solvable = True + + def __init__(self, pc_sequence, pc_series, relative_order, collector=None): + """ + + Parameters + ========== + + pc_sequence : list + A sequence of elements whose classes generate the cyclic factor + groups of pc_series. + pc_series : list + A subnormal sequence of subgroups where each factor group is cyclic. + relative_order : list + The orders of factor groups of pc_series. + collector : Collector + By default, it is None. Collector class provides the + polycyclic presentation with various other functionalities. + + """ + self.pcgs = pc_sequence + self.pc_series = pc_series + self.relative_order = relative_order + self.collector = Collector(self.pcgs, pc_series, relative_order) if not collector else collector + + def is_prime_order(self): + return all(isprime(order) for order in self.relative_order) + + def length(self): + return len(self.pcgs) + + +class Collector(DefaultPrinting): + + """ + References + ========== + + .. [1] Holt, D., Eick, B., O'Brien, E. + "Handbook of Computational Group Theory" + Section 8.1.3 + """ + + def __init__(self, pcgs, pc_series, relative_order, free_group_=None, pc_presentation=None): + """ + + Most of the parameters for the Collector class are the same as for PolycyclicGroup. + Others are described below. + + Parameters + ========== + + free_group_ : tuple + free_group_ provides the mapping of polycyclic generating + sequence with the free group elements. + pc_presentation : dict + Provides the presentation of polycyclic groups with the + help of power and conjugate relators. + + See Also + ======== + + PolycyclicGroup + + """ + self.pcgs = pcgs + self.pc_series = pc_series + self.relative_order = relative_order + self.free_group = free_group('x:{}'.format(len(pcgs)))[0] if not free_group_ else free_group_ + self.index = {s: i for i, s in enumerate(self.free_group.symbols)} + self.pc_presentation = self.pc_relators() + + def minimal_uncollected_subword(self, word): + r""" + Returns the minimal uncollected subwords. + + Explanation + =========== + + A word ``v`` defined on generators in ``X`` is a minimal + uncollected subword of the word ``w`` if ``v`` is a subword + of ``w`` and it has one of the following form + + * `v = {x_{i+1}}^{a_j}x_i` + + * `v = {x_{i+1}}^{a_j}{x_i}^{-1}` + + * `v = {x_i}^{a_j}` + + for `a_j` not in `\{1, \ldots, s-1\}`. Where, ``s`` is the power + exponent of the corresponding generator. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> from sympy.combinatorics import free_group + >>> G = SymmetricGroup(4) + >>> PcGroup = G.polycyclic_group() + >>> collector = PcGroup.collector + >>> F, x1, x2 = free_group("x1, x2") + >>> word = x2**2*x1**7 + >>> collector.minimal_uncollected_subword(word) + ((x2, 2),) + + """ + # To handle the case word = + if not word: + return None + + array = word.array_form + re = self.relative_order + index = self.index + + for i in range(len(array)): + s1, e1 = array[i] + + if re[index[s1]] and (e1 < 0 or e1 > re[index[s1]]-1): + return ((s1, e1), ) + + for i in range(len(array)-1): + s1, e1 = array[i] + s2, e2 = array[i+1] + + if index[s1] > index[s2]: + e = 1 if e2 > 0 else -1 + return ((s1, e1), (s2, e)) + + return None + + def relations(self): + """ + Separates the given relators of pc presentation in power and + conjugate relations. + + Returns + ======= + + (power_rel, conj_rel) + Separates pc presentation into power and conjugate relations. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> G = SymmetricGroup(3) + >>> PcGroup = G.polycyclic_group() + >>> collector = PcGroup.collector + >>> power_rel, conj_rel = collector.relations() + >>> power_rel + {x0**2: (), x1**3: ()} + >>> conj_rel + {x0**-1*x1*x0: x1**2} + + See Also + ======== + + pc_relators + + """ + power_relators = {} + conjugate_relators = {} + for key, value in self.pc_presentation.items(): + if len(key.array_form) == 1: + power_relators[key] = value + else: + conjugate_relators[key] = value + return power_relators, conjugate_relators + + def subword_index(self, word, w): + """ + Returns the start and ending index of a given + subword in a word. + + Parameters + ========== + + word : FreeGroupElement + word defined on free group elements for a + polycyclic group. + w : FreeGroupElement + subword of a given word, whose starting and + ending index to be computed. + + Returns + ======= + + (i, j) + A tuple containing starting and ending index of ``w`` + in the given word. If not exists, (-1,-1) is returned. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> from sympy.combinatorics import free_group + >>> G = SymmetricGroup(4) + >>> PcGroup = G.polycyclic_group() + >>> collector = PcGroup.collector + >>> F, x1, x2 = free_group("x1, x2") + >>> word = x2**2*x1**7 + >>> w = x2**2*x1 + >>> collector.subword_index(word, w) + (0, 3) + >>> w = x1**7 + >>> collector.subword_index(word, w) + (2, 9) + >>> w = x1**8 + >>> collector.subword_index(word, w) + (-1, -1) + + """ + low = -1 + high = -1 + for i in range(len(word)-len(w)+1): + if word.subword(i, i+len(w)) == w: + low = i + high = i+len(w) + break + return low, high + + def map_relation(self, w): + """ + Return a conjugate relation. + + Explanation + =========== + + Given a word formed by two free group elements, the + corresponding conjugate relation with those free + group elements is formed and mapped with the collected + word in the polycyclic presentation. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> from sympy.combinatorics import free_group + >>> G = SymmetricGroup(3) + >>> PcGroup = G.polycyclic_group() + >>> collector = PcGroup.collector + >>> F, x0, x1 = free_group("x0, x1") + >>> w = x1*x0 + >>> collector.map_relation(w) + x1**2 + + See Also + ======== + + pc_presentation + + """ + array = w.array_form + s1 = array[0][0] + s2 = array[1][0] + key = ((s2, -1), (s1, 1), (s2, 1)) + key = self.free_group.dtype(key) + return self.pc_presentation[key] + + + def collected_word(self, word): + r""" + Return the collected form of a word. + + Explanation + =========== + + A word ``w`` is called collected, if `w = {x_{i_1}}^{a_1} * \ldots * + {x_{i_r}}^{a_r}` with `i_1 < i_2< \ldots < i_r` and `a_j` is in + `\{1, \ldots, {s_j}-1\}`. + + Otherwise w is uncollected. + + Parameters + ========== + + word : FreeGroupElement + An uncollected word. + + Returns + ======= + + word + A collected word of form `w = {x_{i_1}}^{a_1}, \ldots, + {x_{i_r}}^{a_r}` with `i_1, i_2, \ldots, i_r` and `a_j \in + \{1, \ldots, {s_j}-1\}`. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> from sympy.combinatorics.perm_groups import PermutationGroup + >>> from sympy.combinatorics import free_group + >>> G = SymmetricGroup(4) + >>> PcGroup = G.polycyclic_group() + >>> collector = PcGroup.collector + >>> F, x0, x1, x2, x3 = free_group("x0, x1, x2, x3") + >>> word = x3*x2*x1*x0 + >>> collected_word = collector.collected_word(word) + >>> free_to_perm = {} + >>> free_group = collector.free_group + >>> for sym, gen in zip(free_group.symbols, collector.pcgs): + ... free_to_perm[sym] = gen + >>> G1 = PermutationGroup() + >>> for w in word: + ... sym = w[0] + ... perm = free_to_perm[sym] + ... G1 = PermutationGroup([perm] + G1.generators) + >>> G2 = PermutationGroup() + >>> for w in collected_word: + ... sym = w[0] + ... perm = free_to_perm[sym] + ... G2 = PermutationGroup([perm] + G2.generators) + + The two are not identical, but they are equivalent: + + >>> G1.equals(G2), G1 == G2 + (True, False) + + See Also + ======== + + minimal_uncollected_subword + + """ + free_group = self.free_group + while True: + w = self.minimal_uncollected_subword(word) + if not w: + break + + low, high = self.subword_index(word, free_group.dtype(w)) + if low == -1: + continue + + s1, e1 = w[0] + if len(w) == 1: + re = self.relative_order[self.index[s1]] + q = e1 // re + r = e1-q*re + + key = ((w[0][0], re), ) + key = free_group.dtype(key) + if self.pc_presentation[key]: + presentation = self.pc_presentation[key].array_form + sym, exp = presentation[0] + word_ = ((w[0][0], r), (sym, q*exp)) + word_ = free_group.dtype(word_) + else: + if r != 0: + word_ = ((w[0][0], r), ) + word_ = free_group.dtype(word_) + else: + word_ = None + word = word.eliminate_word(free_group.dtype(w), word_) + + if len(w) == 2 and w[1][1] > 0: + s2, e2 = w[1] + s2 = ((s2, 1), ) + s2 = free_group.dtype(s2) + word_ = self.map_relation(free_group.dtype(w)) + word_ = s2*word_**e1 + word_ = free_group.dtype(word_) + word = word.substituted_word(low, high, word_) + + elif len(w) == 2 and w[1][1] < 0: + s2, e2 = w[1] + s2 = ((s2, 1), ) + s2 = free_group.dtype(s2) + word_ = self.map_relation(free_group.dtype(w)) + word_ = s2**-1*word_**e1 + word_ = free_group.dtype(word_) + word = word.substituted_word(low, high, word_) + + return word + + + def pc_relators(self): + r""" + Return the polycyclic presentation. + + Explanation + =========== + + There are two types of relations used in polycyclic + presentation. + + * Power relations : Power relators are of the form `x_i^{re_i}`, + where `i \in \{0, \ldots, \mathrm{len(pcgs)}\}`, ``x`` represents polycyclic + generator and ``re`` is the corresponding relative order. + + * Conjugate relations : Conjugate relators are of the form `x_j^-1x_ix_j`, + where `j < i \in \{0, \ldots, \mathrm{len(pcgs)}\}`. + + Returns + ======= + + A dictionary with power and conjugate relations as key and + their collected form as corresponding values. + + Notes + ===== + + Identity Permutation is mapped with empty ``()``. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> from sympy.combinatorics.permutations import Permutation + >>> S = SymmetricGroup(49).sylow_subgroup(7) + >>> der = S.derived_series() + >>> G = der[len(der)-2] + >>> PcGroup = G.polycyclic_group() + >>> collector = PcGroup.collector + >>> pcgs = PcGroup.pcgs + >>> len(pcgs) + 6 + >>> free_group = collector.free_group + >>> pc_resentation = collector.pc_presentation + >>> free_to_perm = {} + >>> for s, g in zip(free_group.symbols, pcgs): + ... free_to_perm[s] = g + + >>> for k, v in pc_resentation.items(): + ... k_array = k.array_form + ... if v != (): + ... v_array = v.array_form + ... lhs = Permutation() + ... for gen in k_array: + ... s = gen[0] + ... e = gen[1] + ... lhs = lhs*free_to_perm[s]**e + ... if v == (): + ... assert lhs.is_identity + ... continue + ... rhs = Permutation() + ... for gen in v_array: + ... s = gen[0] + ... e = gen[1] + ... rhs = rhs*free_to_perm[s]**e + ... assert lhs == rhs + + """ + free_group = self.free_group + rel_order = self.relative_order + pc_relators = {} + perm_to_free = {} + pcgs = self.pcgs + + for gen, s in zip(pcgs, free_group.generators): + perm_to_free[gen**-1] = s**-1 + perm_to_free[gen] = s + + pcgs = pcgs[::-1] + series = self.pc_series[::-1] + rel_order = rel_order[::-1] + collected_gens = [] + + for i, gen in enumerate(pcgs): + re = rel_order[i] + relation = perm_to_free[gen]**re + G = series[i] + + l = G.generator_product(gen**re, original = True) + l.reverse() + + word = free_group.identity + for g in l: + word = word*perm_to_free[g] + + word = self.collected_word(word) + pc_relators[relation] = word if word else () + self.pc_presentation = pc_relators + + collected_gens.append(gen) + if len(collected_gens) > 1: + conj = collected_gens[len(collected_gens)-1] + conjugator = perm_to_free[conj] + + for j in range(len(collected_gens)-1): + conjugated = perm_to_free[collected_gens[j]] + + relation = conjugator**-1*conjugated*conjugator + gens = conj**-1*collected_gens[j]*conj + + l = G.generator_product(gens, original = True) + l.reverse() + word = free_group.identity + for g in l: + word = word*perm_to_free[g] + + word = self.collected_word(word) + pc_relators[relation] = word if word else () + self.pc_presentation = pc_relators + + return pc_relators + + def exponent_vector(self, element): + r""" + Return the exponent vector of length equal to the + length of polycyclic generating sequence. + + Explanation + =========== + + For a given generator/element ``g`` of the polycyclic group, + it can be represented as `g = {x_1}^{e_1}, \ldots, {x_n}^{e_n}`, + where `x_i` represents polycyclic generators and ``n`` is + the number of generators in the free_group equal to the length + of pcgs. + + Parameters + ========== + + element : Permutation + Generator of a polycyclic group. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> from sympy.combinatorics.permutations import Permutation + >>> G = SymmetricGroup(4) + >>> PcGroup = G.polycyclic_group() + >>> collector = PcGroup.collector + >>> pcgs = PcGroup.pcgs + >>> collector.exponent_vector(G[0]) + [1, 0, 0, 0] + >>> exp = collector.exponent_vector(G[1]) + >>> g = Permutation() + >>> for i in range(len(exp)): + ... g = g*pcgs[i]**exp[i] if exp[i] else g + >>> assert g == G[1] + + References + ========== + + .. [1] Holt, D., Eick, B., O'Brien, E. + "Handbook of Computational Group Theory" + Section 8.1.1, Definition 8.4 + + """ + free_group = self.free_group + G = PermutationGroup() + for g in self.pcgs: + G = PermutationGroup([g] + G.generators) + gens = G.generator_product(element, original = True) + gens.reverse() + + perm_to_free = {} + for sym, g in zip(free_group.generators, self.pcgs): + perm_to_free[g**-1] = sym**-1 + perm_to_free[g] = sym + w = free_group.identity + for g in gens: + w = w*perm_to_free[g] + + word = self.collected_word(w) + + index = self.index + exp_vector = [0]*len(free_group) + word = word.array_form + for t in word: + exp_vector[index[t[0]]] = t[1] + return exp_vector + + def depth(self, element): + r""" + Return the depth of a given element. + + Explanation + =========== + + The depth of a given element ``g`` is defined by + `\mathrm{dep}[g] = i` if `e_1 = e_2 = \ldots = e_{i-1} = 0` + and `e_i != 0`, where ``e`` represents the exponent-vector. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> G = SymmetricGroup(3) + >>> PcGroup = G.polycyclic_group() + >>> collector = PcGroup.collector + >>> collector.depth(G[0]) + 2 + >>> collector.depth(G[1]) + 1 + + References + ========== + + .. [1] Holt, D., Eick, B., O'Brien, E. + "Handbook of Computational Group Theory" + Section 8.1.1, Definition 8.5 + + """ + exp_vector = self.exponent_vector(element) + return next((i+1 for i, x in enumerate(exp_vector) if x), len(self.pcgs)+1) + + def leading_exponent(self, element): + r""" + Return the leading non-zero exponent. + + Explanation + =========== + + The leading exponent for a given element `g` is defined + by `\mathrm{leading\_exponent}[g]` `= e_i`, if `\mathrm{depth}[g] = i`. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> G = SymmetricGroup(3) + >>> PcGroup = G.polycyclic_group() + >>> collector = PcGroup.collector + >>> collector.leading_exponent(G[1]) + 1 + + """ + exp_vector = self.exponent_vector(element) + depth = self.depth(element) + if depth != len(self.pcgs)+1: + return exp_vector[depth-1] + return None + + def _sift(self, z, g): + h = g + d = self.depth(h) + while d < len(self.pcgs) and z[d-1] != 1: + k = z[d-1] + e = self.leading_exponent(h)*(self.leading_exponent(k))**-1 + e = e % self.relative_order[d-1] + h = k**-e*h + d = self.depth(h) + return h + + def induced_pcgs(self, gens): + """ + + Parameters + ========== + + gens : list + A list of generators on which polycyclic subgroup + is to be defined. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> S = SymmetricGroup(8) + >>> G = S.sylow_subgroup(2) + >>> PcGroup = G.polycyclic_group() + >>> collector = PcGroup.collector + >>> gens = [G[0], G[1]] + >>> ipcgs = collector.induced_pcgs(gens) + >>> [gen.order() for gen in ipcgs] + [2, 2, 2] + >>> G = S.sylow_subgroup(3) + >>> PcGroup = G.polycyclic_group() + >>> collector = PcGroup.collector + >>> gens = [G[0], G[1]] + >>> ipcgs = collector.induced_pcgs(gens) + >>> [gen.order() for gen in ipcgs] + [3] + + """ + z = [1]*len(self.pcgs) + G = gens + while G: + g = G.pop(0) + h = self._sift(z, g) + d = self.depth(h) + if d < len(self.pcgs): + for gen in z: + if gen != 1: + G.append(h**-1*gen**-1*h*gen) + z[d-1] = h + z = [gen for gen in z if gen != 1] + return z + + def constructive_membership_test(self, ipcgs, g): + """ + Return the exponent vector for induced pcgs. + """ + e = [0]*len(ipcgs) + h = g + d = self.depth(h) + for i, gen in enumerate(ipcgs): + while self.depth(gen) == d: + f = self.leading_exponent(h)*self.leading_exponent(gen) + f = f % self.relative_order[d-1] + h = gen**(-f)*h + e[i] = f + d = self.depth(h) + if h == 1: + return e + return False diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/perm_groups.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/perm_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..24359cb546246d63470de92b2fb2ab0804fc9886 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/perm_groups.py @@ -0,0 +1,5459 @@ +from math import factorial as _factorial, log, prod +from itertools import chain, product + + +from sympy.combinatorics import Permutation +from sympy.combinatorics.permutations import (_af_commutes_with, _af_invert, + _af_rmul, _af_rmuln, _af_pow, Cycle) +from sympy.combinatorics.util import (_check_cycles_alt_sym, + _distribute_gens_by_base, _orbits_transversals_from_bsgs, + _handle_precomputed_bsgs, _base_ordering, _strong_gens_from_distr, + _strip, _strip_af) +from sympy.core import Basic +from sympy.core.random import _randrange, randrange, choice +from sympy.core.symbol import Symbol +from sympy.core.sympify import _sympify +from sympy.functions.combinatorial.factorials import factorial +from sympy.ntheory import primefactors, sieve +from sympy.ntheory.factor_ import (factorint, multiplicity) +from sympy.ntheory.primetest import isprime +from sympy.utilities.iterables import has_variety, is_sequence, uniq + +rmul = Permutation.rmul_with_af +_af_new = Permutation._af_new + + +class PermutationGroup(Basic): + r"""The class defining a Permutation group. + + Explanation + =========== + + ``PermutationGroup([p1, p2, ..., pn])`` returns the permutation group + generated by the list of permutations. This group can be supplied + to Polyhedron if one desires to decorate the elements to which the + indices of the permutation refer. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> from sympy.combinatorics import Polyhedron + + The permutations corresponding to motion of the front, right and + bottom face of a $2 \times 2$ Rubik's cube are defined: + + >>> F = Permutation(2, 19, 21, 8)(3, 17, 20, 10)(4, 6, 7, 5) + >>> R = Permutation(1, 5, 21, 14)(3, 7, 23, 12)(8, 10, 11, 9) + >>> D = Permutation(6, 18, 14, 10)(7, 19, 15, 11)(20, 22, 23, 21) + + These are passed as permutations to PermutationGroup: + + >>> G = PermutationGroup(F, R, D) + >>> G.order() + 3674160 + + The group can be supplied to a Polyhedron in order to track the + objects being moved. An example involving the $2 \times 2$ Rubik's cube is + given there, but here is a simple demonstration: + + >>> a = Permutation(2, 1) + >>> b = Permutation(1, 0) + >>> G = PermutationGroup(a, b) + >>> P = Polyhedron(list('ABC'), pgroup=G) + >>> P.corners + (A, B, C) + >>> P.rotate(0) # apply permutation 0 + >>> P.corners + (A, C, B) + >>> P.reset() + >>> P.corners + (A, B, C) + + Or one can make a permutation as a product of selected permutations + and apply them to an iterable directly: + + >>> P10 = G.make_perm([0, 1]) + >>> P10('ABC') + ['C', 'A', 'B'] + + See Also + ======== + + sympy.combinatorics.polyhedron.Polyhedron, + sympy.combinatorics.permutations.Permutation + + References + ========== + + .. [1] Holt, D., Eick, B., O'Brien, E. + "Handbook of Computational Group Theory" + + .. [2] Seress, A. + "Permutation Group Algorithms" + + .. [3] https://en.wikipedia.org/wiki/Schreier_vector + + .. [4] https://en.wikipedia.org/wiki/Nielsen_transformation#Product_replacement_algorithm + + .. [5] Frank Celler, Charles R.Leedham-Green, Scott H.Murray, + Alice C.Niemeyer, and E.A.O'Brien. "Generating Random + Elements of a Finite Group" + + .. [6] https://en.wikipedia.org/wiki/Block_%28permutation_group_theory%29 + + .. [7] https://algorithmist.com/wiki/Union_find + + .. [8] https://en.wikipedia.org/wiki/Multiply_transitive_group#Multiply_transitive_groups + + .. [9] https://en.wikipedia.org/wiki/Center_%28group_theory%29 + + .. [10] https://en.wikipedia.org/wiki/Centralizer_and_normalizer + + .. [11] https://groupprops.subwiki.org/wiki/Derived_subgroup + + .. [12] https://en.wikipedia.org/wiki/Nilpotent_group + + .. [13] https://www.math.colostate.edu/~hulpke/CGT/cgtnotes.pdf + + .. [14] https://docs.gap-system.org/doc/ref/manual.pdf + + """ + is_group = True + + def __new__(cls, *args, dups=True, **kwargs): + """The default constructor. Accepts Cycle and Permutation forms. + Removes duplicates unless ``dups`` keyword is ``False``. + """ + if not args: + args = [Permutation()] + else: + args = list(args[0] if is_sequence(args[0]) else args) + if not args: + args = [Permutation()] + if any(isinstance(a, Cycle) for a in args): + args = [Permutation(a) for a in args] + if has_variety(a.size for a in args): + degree = kwargs.pop('degree', None) + if degree is None: + degree = max(a.size for a in args) + for i in range(len(args)): + if args[i].size != degree: + args[i] = Permutation(args[i], size=degree) + if dups: + args = list(uniq([_af_new(list(a)) for a in args])) + if len(args) > 1: + args = [g for g in args if not g.is_identity] + return Basic.__new__(cls, *args, **kwargs) + + def __init__(self, *args, **kwargs): + self._generators = list(self.args) + self._order = None + self._elements = [] + self._center = None + self._is_abelian = None + self._is_transitive = None + self._is_sym = None + self._is_alt = None + self._is_primitive = None + self._is_nilpotent = None + self._is_solvable = None + self._is_trivial = None + self._transitivity_degree = None + self._max_div = None + self._is_perfect = None + self._is_cyclic = None + self._is_dihedral = None + self._r = len(self._generators) + self._degree = self._generators[0].size + + # these attributes are assigned after running schreier_sims + self._base = [] + self._strong_gens = [] + self._strong_gens_slp = [] + self._basic_orbits = [] + self._transversals = [] + self._transversal_slp = [] + + # these attributes are assigned after running _random_pr_init + self._random_gens = [] + + # finite presentation of the group as an instance of `FpGroup` + self._fp_presentation = None + + def __getitem__(self, i): + return self._generators[i] + + def __contains__(self, i): + """Return ``True`` if *i* is contained in PermutationGroup. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> p = Permutation(1, 2, 3) + >>> Permutation(3) in PermutationGroup(p) + True + + """ + if not isinstance(i, Permutation): + raise TypeError("A PermutationGroup contains only Permutations as " + "elements, not elements of type %s" % type(i)) + return self.contains(i) + + def __len__(self): + return len(self._generators) + + def equals(self, other): + """Return ``True`` if PermutationGroup generated by elements in the + group are same i.e they represent the same PermutationGroup. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> p = Permutation(0, 1, 2, 3, 4, 5) + >>> G = PermutationGroup([p, p**2]) + >>> H = PermutationGroup([p**2, p]) + >>> G.generators == H.generators + False + >>> G.equals(H) + True + + """ + if not isinstance(other, PermutationGroup): + return False + + set_self_gens = set(self.generators) + set_other_gens = set(other.generators) + + # before reaching the general case there are also certain + # optimisation and obvious cases requiring less or no actual + # computation. + if set_self_gens == set_other_gens: + return True + + # in the most general case it will check that each generator of + # one group belongs to the other PermutationGroup and vice-versa + for gen1 in set_self_gens: + if not other.contains(gen1): + return False + for gen2 in set_other_gens: + if not self.contains(gen2): + return False + return True + + def __mul__(self, other): + """ + Return the direct product of two permutation groups as a permutation + group. + + Explanation + =========== + + This implementation realizes the direct product by shifting the index + set for the generators of the second group: so if we have ``G`` acting + on ``n1`` points and ``H`` acting on ``n2`` points, ``G*H`` acts on + ``n1 + n2`` points. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import CyclicGroup + >>> G = CyclicGroup(5) + >>> H = G*G + >>> H + PermutationGroup([ + (9)(0 1 2 3 4), + (5 6 7 8 9)]) + >>> H.order() + 25 + + """ + if isinstance(other, Permutation): + return Coset(other, self, dir='+') + gens1 = [perm._array_form for perm in self.generators] + gens2 = [perm._array_form for perm in other.generators] + n1 = self._degree + n2 = other._degree + start = list(range(n1)) + end = list(range(n1, n1 + n2)) + for i in range(len(gens2)): + gens2[i] = [x + n1 for x in gens2[i]] + gens2 = [start + gen for gen in gens2] + gens1 = [gen + end for gen in gens1] + together = gens1 + gens2 + gens = [_af_new(x) for x in together] + return PermutationGroup(gens) + + def _random_pr_init(self, r, n, _random_prec_n=None): + r"""Initialize random generators for the product replacement algorithm. + + Explanation + =========== + + The implementation uses a modification of the original product + replacement algorithm due to Leedham-Green, as described in [1], + pp. 69-71; also, see [2], pp. 27-29 for a detailed theoretical + analysis of the original product replacement algorithm, and [4]. + + The product replacement algorithm is used for producing random, + uniformly distributed elements of a group `G` with a set of generators + `S`. For the initialization ``_random_pr_init``, a list ``R`` of + `\max\{r, |S|\}` group generators is created as the attribute + ``G._random_gens``, repeating elements of `S` if necessary, and the + identity element of `G` is appended to ``R`` - we shall refer to this + last element as the accumulator. Then the function ``random_pr()`` + is called ``n`` times, randomizing the list ``R`` while preserving + the generation of `G` by ``R``. The function ``random_pr()`` itself + takes two random elements ``g, h`` among all elements of ``R`` but + the accumulator and replaces ``g`` with a randomly chosen element + from `\{gh, g(~h), hg, (~h)g\}`. Then the accumulator is multiplied + by whatever ``g`` was replaced by. The new value of the accumulator is + then returned by ``random_pr()``. + + The elements returned will eventually (for ``n`` large enough) become + uniformly distributed across `G` ([5]). For practical purposes however, + the values ``n = 50, r = 11`` are suggested in [1]. + + Notes + ===== + + THIS FUNCTION HAS SIDE EFFECTS: it changes the attribute + self._random_gens + + See Also + ======== + + random_pr + + """ + deg = self.degree + random_gens = [x._array_form for x in self.generators] + k = len(random_gens) + if k < r: + for i in range(k, r): + random_gens.append(random_gens[i - k]) + acc = list(range(deg)) + random_gens.append(acc) + self._random_gens = random_gens + + # handle randomized input for testing purposes + if _random_prec_n is None: + for i in range(n): + self.random_pr() + else: + for i in range(n): + self.random_pr(_random_prec=_random_prec_n[i]) + + def _union_find_merge(self, first, second, ranks, parents, not_rep): + """Merges two classes in a union-find data structure. + + Explanation + =========== + + Used in the implementation of Atkinson's algorithm as suggested in [1], + pp. 83-87. The class merging process uses union by rank as an + optimization. ([7]) + + Notes + ===== + + THIS FUNCTION HAS SIDE EFFECTS: the list of class representatives, + ``parents``, the list of class sizes, ``ranks``, and the list of + elements that are not representatives, ``not_rep``, are changed due to + class merging. + + See Also + ======== + + minimal_block, _union_find_rep + + References + ========== + + .. [1] Holt, D., Eick, B., O'Brien, E. + "Handbook of computational group theory" + + .. [7] https://algorithmist.com/wiki/Union_find + + """ + rep_first = self._union_find_rep(first, parents) + rep_second = self._union_find_rep(second, parents) + if rep_first != rep_second: + # union by rank + if ranks[rep_first] >= ranks[rep_second]: + new_1, new_2 = rep_first, rep_second + else: + new_1, new_2 = rep_second, rep_first + total_rank = ranks[new_1] + ranks[new_2] + if total_rank > self.max_div: + return -1 + parents[new_2] = new_1 + ranks[new_1] = total_rank + not_rep.append(new_2) + return 1 + return 0 + + def _union_find_rep(self, num, parents): + """Find representative of a class in a union-find data structure. + + Explanation + =========== + + Used in the implementation of Atkinson's algorithm as suggested in [1], + pp. 83-87. After the representative of the class to which ``num`` + belongs is found, path compression is performed as an optimization + ([7]). + + Notes + ===== + + THIS FUNCTION HAS SIDE EFFECTS: the list of class representatives, + ``parents``, is altered due to path compression. + + See Also + ======== + + minimal_block, _union_find_merge + + References + ========== + + .. [1] Holt, D., Eick, B., O'Brien, E. + "Handbook of computational group theory" + + .. [7] https://algorithmist.com/wiki/Union_find + + """ + rep, parent = num, parents[num] + while parent != rep: + rep = parent + parent = parents[rep] + # path compression + temp, parent = num, parents[num] + while parent != rep: + parents[temp] = rep + temp = parent + parent = parents[temp] + return rep + + @property + def base(self): + r"""Return a base from the Schreier-Sims algorithm. + + Explanation + =========== + + For a permutation group `G`, a base is a sequence of points + `B = (b_1, b_2, \dots, b_k)` such that no element of `G` apart + from the identity fixes all the points in `B`. The concepts of + a base and strong generating set and their applications are + discussed in depth in [1], pp. 87-89 and [2], pp. 55-57. + + An alternative way to think of `B` is that it gives the + indices of the stabilizer cosets that contain more than the + identity permutation. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> G = PermutationGroup([Permutation(0, 1, 3)(2, 4)]) + >>> G.base + [0, 2] + + See Also + ======== + + strong_gens, basic_transversals, basic_orbits, basic_stabilizers + + """ + if self._base == []: + self.schreier_sims() + return self._base + + def baseswap(self, base, strong_gens, pos, randomized=False, + transversals=None, basic_orbits=None, strong_gens_distr=None): + r"""Swap two consecutive base points in base and strong generating set. + + Explanation + =========== + + If a base for a group `G` is given by `(b_1, b_2, \dots, b_k)`, this + function returns a base `(b_1, b_2, \dots, b_{i+1}, b_i, \dots, b_k)`, + where `i` is given by ``pos``, and a strong generating set relative + to that base. The original base and strong generating set are not + modified. + + The randomized version (default) is of Las Vegas type. + + Parameters + ========== + + base, strong_gens + The base and strong generating set. + pos + The position at which swapping is performed. + randomized + A switch between randomized and deterministic version. + transversals + The transversals for the basic orbits, if known. + basic_orbits + The basic orbits, if known. + strong_gens_distr + The strong generators distributed by basic stabilizers, if known. + + Returns + ======= + + (base, strong_gens) + ``base`` is the new base, and ``strong_gens`` is a generating set + relative to it. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> from sympy.combinatorics.testutil import _verify_bsgs + >>> from sympy.combinatorics.perm_groups import PermutationGroup + >>> S = SymmetricGroup(4) + >>> S.schreier_sims() + >>> S.base + [0, 1, 2] + >>> base, gens = S.baseswap(S.base, S.strong_gens, 1, randomized=False) + >>> base, gens + ([0, 2, 1], + [(0 1 2 3), (3)(0 1), (1 3 2), + (2 3), (1 3)]) + + check that base, gens is a BSGS + + >>> S1 = PermutationGroup(gens) + >>> _verify_bsgs(S1, base, gens) + True + + See Also + ======== + + schreier_sims + + Notes + ===== + + The deterministic version of the algorithm is discussed in + [1], pp. 102-103; the randomized version is discussed in [1], p.103, and + [2], p.98. It is of Las Vegas type. + Notice that [1] contains a mistake in the pseudocode and + discussion of BASESWAP: on line 3 of the pseudocode, + `|\beta_{i+1}^{\left\langle T\right\rangle}|` should be replaced by + `|\beta_{i}^{\left\langle T\right\rangle}|`, and the same for the + discussion of the algorithm. + + """ + # construct the basic orbits, generators for the stabilizer chain + # and transversal elements from whatever was provided + transversals, basic_orbits, strong_gens_distr = \ + _handle_precomputed_bsgs(base, strong_gens, transversals, + basic_orbits, strong_gens_distr) + base_len = len(base) + degree = self.degree + # size of orbit of base[pos] under the stabilizer we seek to insert + # in the stabilizer chain at position pos + 1 + size = len(basic_orbits[pos])*len(basic_orbits[pos + 1]) \ + //len(_orbit(degree, strong_gens_distr[pos], base[pos + 1])) + # initialize the wanted stabilizer by a subgroup + if pos + 2 > base_len - 1: + T = [] + else: + T = strong_gens_distr[pos + 2][:] + # randomized version + if randomized is True: + stab_pos = PermutationGroup(strong_gens_distr[pos]) + schreier_vector = stab_pos.schreier_vector(base[pos + 1]) + # add random elements of the stabilizer until they generate it + while len(_orbit(degree, T, base[pos])) != size: + new = stab_pos.random_stab(base[pos + 1], + schreier_vector=schreier_vector) + T.append(new) + # deterministic version + else: + Gamma = set(basic_orbits[pos]) + Gamma.remove(base[pos]) + if base[pos + 1] in Gamma: + Gamma.remove(base[pos + 1]) + # add elements of the stabilizer until they generate it by + # ruling out member of the basic orbit of base[pos] along the way + while len(_orbit(degree, T, base[pos])) != size: + gamma = next(iter(Gamma)) + x = transversals[pos][gamma] + temp = x._array_form.index(base[pos + 1]) # (~x)(base[pos + 1]) + if temp not in basic_orbits[pos + 1]: + Gamma = Gamma - _orbit(degree, T, gamma) + else: + y = transversals[pos + 1][temp] + el = rmul(x, y) + if el(base[pos]) not in _orbit(degree, T, base[pos]): + T.append(el) + Gamma = Gamma - _orbit(degree, T, base[pos]) + # build the new base and strong generating set + strong_gens_new_distr = strong_gens_distr[:] + strong_gens_new_distr[pos + 1] = T + base_new = base[:] + base_new[pos], base_new[pos + 1] = base_new[pos + 1], base_new[pos] + strong_gens_new = _strong_gens_from_distr(strong_gens_new_distr) + for gen in T: + if gen not in strong_gens_new: + strong_gens_new.append(gen) + return base_new, strong_gens_new + + @property + def basic_orbits(self): + r""" + Return the basic orbits relative to a base and strong generating set. + + Explanation + =========== + + If `(b_1, b_2, \dots, b_k)` is a base for a group `G`, and + `G^{(i)} = G_{b_1, b_2, \dots, b_{i-1}}` is the ``i``-th basic stabilizer + (so that `G^{(1)} = G`), the ``i``-th basic orbit relative to this base + is the orbit of `b_i` under `G^{(i)}`. See [1], pp. 87-89 for more + information. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> S = SymmetricGroup(4) + >>> S.basic_orbits + [[0, 1, 2, 3], [1, 2, 3], [2, 3]] + + See Also + ======== + + base, strong_gens, basic_transversals, basic_stabilizers + + """ + if self._basic_orbits == []: + self.schreier_sims() + return self._basic_orbits + + @property + def basic_stabilizers(self): + r""" + Return a chain of stabilizers relative to a base and strong generating + set. + + Explanation + =========== + + The ``i``-th basic stabilizer `G^{(i)}` relative to a base + `(b_1, b_2, \dots, b_k)` is `G_{b_1, b_2, \dots, b_{i-1}}`. For more + information, see [1], pp. 87-89. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import AlternatingGroup + >>> A = AlternatingGroup(4) + >>> A.schreier_sims() + >>> A.base + [0, 1] + >>> for g in A.basic_stabilizers: + ... print(g) + ... + PermutationGroup([ + (3)(0 1 2), + (1 2 3)]) + PermutationGroup([ + (1 2 3)]) + + See Also + ======== + + base, strong_gens, basic_orbits, basic_transversals + + """ + + if self._transversals == []: + self.schreier_sims() + strong_gens = self._strong_gens + base = self._base + if not base: # e.g. if self is trivial + return [] + strong_gens_distr = _distribute_gens_by_base(base, strong_gens) + basic_stabilizers = [] + for gens in strong_gens_distr: + basic_stabilizers.append(PermutationGroup(gens)) + return basic_stabilizers + + @property + def basic_transversals(self): + """ + Return basic transversals relative to a base and strong generating set. + + Explanation + =========== + + The basic transversals are transversals of the basic orbits. They + are provided as a list of dictionaries, each dictionary having + keys - the elements of one of the basic orbits, and values - the + corresponding transversal elements. See [1], pp. 87-89 for more + information. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import AlternatingGroup + >>> A = AlternatingGroup(4) + >>> A.basic_transversals + [{0: (3), 1: (3)(0 1 2), 2: (3)(0 2 1), 3: (0 3 1)}, {1: (3), 2: (1 2 3), 3: (1 3 2)}] + + See Also + ======== + + strong_gens, base, basic_orbits, basic_stabilizers + + """ + + if self._transversals == []: + self.schreier_sims() + return self._transversals + + def composition_series(self): + r""" + Return the composition series for a group as a list + of permutation groups. + + Explanation + =========== + + The composition series for a group `G` is defined as a + subnormal series `G = H_0 > H_1 > H_2 \ldots` A composition + series is a subnormal series such that each factor group + `H(i+1) / H(i)` is simple. + A subnormal series is a composition series only if it is of + maximum length. + + The algorithm works as follows: + Starting with the derived series the idea is to fill + the gap between `G = der[i]` and `H = der[i+1]` for each + `i` independently. Since, all subgroups of the abelian group + `G/H` are normal so, first step is to take the generators + `g` of `G` and add them to generators of `H` one by one. + + The factor groups formed are not simple in general. Each + group is obtained from the previous one by adding one + generator `g`, if the previous group is denoted by `H` + then the next group `K` is generated by `g` and `H`. + The factor group `K/H` is cyclic and it's order is + `K.order()//G.order()`. The series is then extended between + `K` and `H` by groups generated by powers of `g` and `H`. + The series formed is then prepended to the already existing + series. + + Examples + ======== + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> from sympy.combinatorics.named_groups import CyclicGroup + >>> S = SymmetricGroup(12) + >>> G = S.sylow_subgroup(2) + >>> C = G.composition_series() + >>> [H.order() for H in C] + [1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1] + >>> G = S.sylow_subgroup(3) + >>> C = G.composition_series() + >>> [H.order() for H in C] + [243, 81, 27, 9, 3, 1] + >>> G = CyclicGroup(12) + >>> C = G.composition_series() + >>> [H.order() for H in C] + [12, 6, 3, 1] + + """ + der = self.derived_series() + if not all(g.is_identity for g in der[-1].generators): + raise NotImplementedError('Group should be solvable') + series = [] + + for i in range(len(der)-1): + H = der[i+1] + up_seg = [] + for g in der[i].generators: + K = PermutationGroup([g] + H.generators) + order = K.order() // H.order() + down_seg = [] + for p, e in factorint(order).items(): + for _ in range(e): + down_seg.append(PermutationGroup([g] + H.generators)) + g = g**p + up_seg = down_seg + up_seg + H = K + up_seg[0] = der[i] + series.extend(up_seg) + series.append(der[-1]) + return series + + def coset_transversal(self, H): + """Return a transversal of the right cosets of self by its subgroup H + using the second method described in [1], Subsection 4.6.7 + + """ + + if not H.is_subgroup(self): + raise ValueError("The argument must be a subgroup") + + if H.order() == 1: + return self.elements + + self._schreier_sims(base=H.base) # make G.base an extension of H.base + + base = self.base + base_ordering = _base_ordering(base, self.degree) + identity = Permutation(self.degree - 1) + + transversals = self.basic_transversals[:] + # transversals is a list of dictionaries. Get rid of the keys + # so that it is a list of lists and sort each list in + # the increasing order of base[l]^x + for l, t in enumerate(transversals): + transversals[l] = sorted(t.values(), + key = lambda x: base_ordering[base[l]^x]) + + orbits = H.basic_orbits + h_stabs = H.basic_stabilizers + g_stabs = self.basic_stabilizers + + indices = [x.order()//y.order() for x, y in zip(g_stabs, h_stabs)] + + # T^(l) should be a right transversal of H^(l) in G^(l) for + # 1<=l<=len(base). While H^(l) is the trivial group, T^(l) + # contains all the elements of G^(l) so we might just as well + # start with l = len(h_stabs)-1 + if len(g_stabs) > len(h_stabs): + T = g_stabs[len(h_stabs)].elements + else: + T = [identity] + l = len(h_stabs)-1 + t_len = len(T) + while l > -1: + T_next = [] + for u in transversals[l]: + if u == identity: + continue + b = base_ordering[base[l]^u] + for t in T: + p = t*u + if all(base_ordering[h^p] >= b for h in orbits[l]): + T_next.append(p) + if t_len + len(T_next) == indices[l]: + break + if t_len + len(T_next) == indices[l]: + break + T += T_next + t_len += len(T_next) + l -= 1 + T.remove(identity) + T = [identity] + T + return T + + def _coset_representative(self, g, H): + """Return the representative of Hg from the transversal that + would be computed by ``self.coset_transversal(H)``. + + """ + if H.order() == 1: + return g + # The base of self must be an extension of H.base. + if not(self.base[:len(H.base)] == H.base): + self._schreier_sims(base=H.base) + orbits = H.basic_orbits[:] + h_transversals = [list(_.values()) for _ in H.basic_transversals] + transversals = [list(_.values()) for _ in self.basic_transversals] + base = self.base + base_ordering = _base_ordering(base, self.degree) + def step(l, x): + gamma = min(orbits[l], key = lambda y: base_ordering[y^x]) + i = [base[l]^h for h in h_transversals[l]].index(gamma) + x = h_transversals[l][i]*x + if l < len(orbits)-1: + for u in transversals[l]: + if base[l]^u == base[l]^x: + break + x = step(l+1, x*u**-1)*u + return x + return step(0, g) + + def coset_table(self, H): + """Return the standardised (right) coset table of self in H as + a list of lists. + """ + # Maybe this should be made to return an instance of CosetTable + # from fp_groups.py but the class would need to be changed first + # to be compatible with PermutationGroups + + if not H.is_subgroup(self): + raise ValueError("The argument must be a subgroup") + T = self.coset_transversal(H) + n = len(T) + + A = list(chain.from_iterable((gen, gen**-1) + for gen in self.generators)) + + table = [] + for i in range(n): + row = [self._coset_representative(T[i]*x, H) for x in A] + row = [T.index(r) for r in row] + table.append(row) + + # standardize (this is the same as the algorithm used in coset_table) + # If CosetTable is made compatible with PermutationGroups, this + # should be replaced by table.standardize() + A = range(len(A)) + gamma = 1 + for alpha, a in product(range(n), A): + beta = table[alpha][a] + if beta >= gamma: + if beta > gamma: + for x in A: + z = table[gamma][x] + table[gamma][x] = table[beta][x] + table[beta][x] = z + for i in range(n): + if table[i][x] == beta: + table[i][x] = gamma + elif table[i][x] == gamma: + table[i][x] = beta + gamma += 1 + if gamma >= n-1: + return table + + def center(self): + r""" + Return the center of a permutation group. + + Explanation + =========== + + The center for a group `G` is defined as + `Z(G) = \{z\in G | \forall g\in G, zg = gz \}`, + the set of elements of `G` that commute with all elements of `G`. + It is equal to the centralizer of `G` inside `G`, and is naturally a + subgroup of `G` ([9]). + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import DihedralGroup + >>> D = DihedralGroup(4) + >>> G = D.center() + >>> G.order() + 2 + + See Also + ======== + + centralizer + + Notes + ===== + + This is a naive implementation that is a straightforward application + of ``.centralizer()`` + + """ + if not self._center: + self._center = self.centralizer(self) + return self._center + + def centralizer(self, other): + r""" + Return the centralizer of a group/set/element. + + Explanation + =========== + + The centralizer of a set of permutations ``S`` inside + a group ``G`` is the set of elements of ``G`` that commute with all + elements of ``S``:: + + `C_G(S) = \{ g \in G | gs = sg \forall s \in S\}` ([10]) + + Usually, ``S`` is a subset of ``G``, but if ``G`` is a proper subgroup of + the full symmetric group, we allow for ``S`` to have elements outside + ``G``. + + It is naturally a subgroup of ``G``; the centralizer of a permutation + group is equal to the centralizer of any set of generators for that + group, since any element commuting with the generators commutes with + any product of the generators. + + Parameters + ========== + + other + a permutation group/list of permutations/single permutation + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import (SymmetricGroup, + ... CyclicGroup) + >>> S = SymmetricGroup(6) + >>> C = CyclicGroup(6) + >>> H = S.centralizer(C) + >>> H.is_subgroup(C) + True + + See Also + ======== + + subgroup_search + + Notes + ===== + + The implementation is an application of ``.subgroup_search()`` with + tests using a specific base for the group ``G``. + + """ + if hasattr(other, 'generators'): + if other.is_trivial or self.is_trivial: + return self + degree = self.degree + identity = _af_new(list(range(degree))) + orbits = other.orbits() + num_orbits = len(orbits) + orbits.sort(key=lambda x: -len(x)) + long_base = [] + orbit_reps = [None]*num_orbits + orbit_reps_indices = [None]*num_orbits + orbit_descr = [None]*degree + for i in range(num_orbits): + orbit = list(orbits[i]) + orbit_reps[i] = orbit[0] + orbit_reps_indices[i] = len(long_base) + for point in orbit: + orbit_descr[point] = i + long_base = long_base + orbit + base, strong_gens = self.schreier_sims_incremental(base=long_base) + strong_gens_distr = _distribute_gens_by_base(base, strong_gens) + i = 0 + for i in range(len(base)): + if strong_gens_distr[i] == [identity]: + break + base = base[:i] + base_len = i + for j in range(num_orbits): + if base[base_len - 1] in orbits[j]: + break + rel_orbits = orbits[: j + 1] + num_rel_orbits = len(rel_orbits) + transversals = [None]*num_rel_orbits + for j in range(num_rel_orbits): + rep = orbit_reps[j] + transversals[j] = dict( + other.orbit_transversal(rep, pairs=True)) + trivial_test = lambda x: True + tests = [None]*base_len + for l in range(base_len): + if base[l] in orbit_reps: + tests[l] = trivial_test + else: + def test(computed_words, l=l): + g = computed_words[l] + rep_orb_index = orbit_descr[base[l]] + rep = orbit_reps[rep_orb_index] + im = g._array_form[base[l]] + im_rep = g._array_form[rep] + tr_el = transversals[rep_orb_index][base[l]] + # using the definition of transversal, + # base[l]^g = rep^(tr_el*g); + # if g belongs to the centralizer, then + # base[l]^g = (rep^g)^tr_el + return im == tr_el._array_form[im_rep] + tests[l] = test + + def prop(g): + return [rmul(g, gen) for gen in other.generators] == \ + [rmul(gen, g) for gen in other.generators] + return self.subgroup_search(prop, base=base, + strong_gens=strong_gens, tests=tests) + elif hasattr(other, '__getitem__'): + gens = list(other) + return self.centralizer(PermutationGroup(gens)) + elif hasattr(other, 'array_form'): + return self.centralizer(PermutationGroup([other])) + + def commutator(self, G, H): + """ + Return the commutator of two subgroups. + + Explanation + =========== + + For a permutation group ``K`` and subgroups ``G``, ``H``, the + commutator of ``G`` and ``H`` is defined as the group generated + by all the commutators `[g, h] = hgh^{-1}g^{-1}` for ``g`` in ``G`` and + ``h`` in ``H``. It is naturally a subgroup of ``K`` ([1], p.27). + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import (SymmetricGroup, + ... AlternatingGroup) + >>> S = SymmetricGroup(5) + >>> A = AlternatingGroup(5) + >>> G = S.commutator(S, A) + >>> G.is_subgroup(A) + True + + See Also + ======== + + derived_subgroup + + Notes + ===== + + The commutator of two subgroups `H, G` is equal to the normal closure + of the commutators of all the generators, i.e. `hgh^{-1}g^{-1}` for `h` + a generator of `H` and `g` a generator of `G` ([1], p.28) + + """ + ggens = G.generators + hgens = H.generators + commutators = [] + for ggen in ggens: + for hgen in hgens: + commutator = rmul(hgen, ggen, ~hgen, ~ggen) + if commutator not in commutators: + commutators.append(commutator) + res = self.normal_closure(commutators) + return res + + def coset_factor(self, g, factor_index=False): + """Return ``G``'s (self's) coset factorization of ``g`` + + Explanation + =========== + + If ``g`` is an element of ``G`` then it can be written as the product + of permutations drawn from the Schreier-Sims coset decomposition, + + The permutations returned in ``f`` are those for which + the product gives ``g``: ``g = f[n]*...f[1]*f[0]`` where ``n = len(B)`` + and ``B = G.base``. f[i] is one of the permutations in + ``self._basic_orbits[i]``. + + If factor_index==True, + returns a tuple ``[b[0],..,b[n]]``, where ``b[i]`` + belongs to ``self._basic_orbits[i]`` + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation(0, 1, 3, 7, 6, 4)(2, 5) + >>> b = Permutation(0, 1, 3, 2)(4, 5, 7, 6) + >>> G = PermutationGroup([a, b]) + + Define g: + + >>> g = Permutation(7)(1, 2, 4)(3, 6, 5) + + Confirm that it is an element of G: + + >>> G.contains(g) + True + + Thus, it can be written as a product of factors (up to + 3) drawn from u. See below that a factor from u1 and u2 + and the Identity permutation have been used: + + >>> f = G.coset_factor(g) + >>> f[2]*f[1]*f[0] == g + True + >>> f1 = G.coset_factor(g, True); f1 + [0, 4, 4] + >>> tr = G.basic_transversals + >>> f[0] == tr[0][f1[0]] + True + + If g is not an element of G then [] is returned: + + >>> c = Permutation(5, 6, 7) + >>> G.coset_factor(c) + [] + + See Also + ======== + + sympy.combinatorics.util._strip + + """ + if isinstance(g, (Cycle, Permutation)): + g = g.list() + if len(g) != self._degree: + # this could either adjust the size or return [] immediately + # but we don't choose between the two and just signal a possible + # error + raise ValueError('g should be the same size as permutations of G') + I = list(range(self._degree)) + basic_orbits = self.basic_orbits + transversals = self._transversals + factors = [] + base = self.base + h = g + for i in range(len(base)): + beta = h[base[i]] + if beta == base[i]: + factors.append(beta) + continue + if beta not in basic_orbits[i]: + return [] + u = transversals[i][beta]._array_form + h = _af_rmul(_af_invert(u), h) + factors.append(beta) + if h != I: + return [] + if factor_index: + return factors + tr = self.basic_transversals + factors = [tr[i][factors[i]] for i in range(len(base))] + return factors + + def generator_product(self, g, original=False): + r''' + Return a list of strong generators `[s1, \dots, sn]` + s.t `g = sn \times \dots \times s1`. If ``original=True``, make the + list contain only the original group generators + + ''' + product = [] + if g.is_identity: + return [] + if g in self.strong_gens: + if not original or g in self.generators: + return [g] + else: + slp = self._strong_gens_slp[g] + for s in slp: + product.extend(self.generator_product(s, original=True)) + return product + elif g**-1 in self.strong_gens: + g = g**-1 + if not original or g in self.generators: + return [g**-1] + else: + slp = self._strong_gens_slp[g] + for s in slp: + product.extend(self.generator_product(s, original=True)) + l = len(product) + product = [product[l-i-1]**-1 for i in range(l)] + return product + + f = self.coset_factor(g, True) + for i, j in enumerate(f): + slp = self._transversal_slp[i][j] + for s in slp: + if not original: + product.append(self.strong_gens[s]) + else: + s = self.strong_gens[s] + product.extend(self.generator_product(s, original=True)) + return product + + def coset_rank(self, g): + """rank using Schreier-Sims representation. + + Explanation + =========== + + The coset rank of ``g`` is the ordering number in which + it appears in the lexicographic listing according to the + coset decomposition + + The ordering is the same as in G.generate(method='coset'). + If ``g`` does not belong to the group it returns None. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation(0, 1, 3, 7, 6, 4)(2, 5) + >>> b = Permutation(0, 1, 3, 2)(4, 5, 7, 6) + >>> G = PermutationGroup([a, b]) + >>> c = Permutation(7)(2, 4)(3, 5) + >>> G.coset_rank(c) + 16 + >>> G.coset_unrank(16) + (7)(2 4)(3 5) + + See Also + ======== + + coset_factor + + """ + factors = self.coset_factor(g, True) + if not factors: + return None + rank = 0 + b = 1 + transversals = self._transversals + base = self._base + basic_orbits = self._basic_orbits + for i in range(len(base)): + k = factors[i] + j = basic_orbits[i].index(k) + rank += b*j + b = b*len(transversals[i]) + return rank + + def coset_unrank(self, rank, af=False): + """unrank using Schreier-Sims representation + + coset_unrank is the inverse operation of coset_rank + if 0 <= rank < order; otherwise it returns None. + + """ + if rank < 0 or rank >= self.order(): + return None + base = self.base + transversals = self.basic_transversals + basic_orbits = self.basic_orbits + m = len(base) + v = [0]*m + for i in range(m): + rank, c = divmod(rank, len(transversals[i])) + v[i] = basic_orbits[i][c] + a = [transversals[i][v[i]]._array_form for i in range(m)] + h = _af_rmuln(*a) + if af: + return h + else: + return _af_new(h) + + @property + def degree(self): + """Returns the size of the permutations in the group. + + Explanation + =========== + + The number of permutations comprising the group is given by + ``len(group)``; the number of permutations that can be generated + by the group is given by ``group.order()``. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation([1, 0, 2]) + >>> G = PermutationGroup([a]) + >>> G.degree + 3 + >>> len(G) + 1 + >>> G.order() + 2 + >>> list(G.generate()) + [(2), (2)(0 1)] + + See Also + ======== + + order + """ + return self._degree + + @property + def identity(self): + ''' + Return the identity element of the permutation group. + + ''' + return _af_new(list(range(self.degree))) + + @property + def elements(self): + """Returns all the elements of the permutation group as a list + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> p = PermutationGroup(Permutation(1, 3), Permutation(1, 2)) + >>> p.elements + [(3), (3)(1 2), (1 3), (2 3), (1 2 3), (1 3 2)] + + """ + if not self._elements: + self._elements = list(self.generate()) + + return self._elements + + def derived_series(self): + r"""Return the derived series for the group. + + Explanation + =========== + + The derived series for a group `G` is defined as + `G = G_0 > G_1 > G_2 > \ldots` where `G_i = [G_{i-1}, G_{i-1}]`, + i.e. `G_i` is the derived subgroup of `G_{i-1}`, for + `i\in\mathbb{N}`. When we have `G_k = G_{k-1}` for some + `k\in\mathbb{N}`, the series terminates. + + Returns + ======= + + A list of permutation groups containing the members of the derived + series in the order `G = G_0, G_1, G_2, \ldots`. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import (SymmetricGroup, + ... AlternatingGroup, DihedralGroup) + >>> A = AlternatingGroup(5) + >>> len(A.derived_series()) + 1 + >>> S = SymmetricGroup(4) + >>> len(S.derived_series()) + 4 + >>> S.derived_series()[1].is_subgroup(AlternatingGroup(4)) + True + >>> S.derived_series()[2].is_subgroup(DihedralGroup(2)) + True + + See Also + ======== + + derived_subgroup + + """ + res = [self] + current = self + nxt = self.derived_subgroup() + while not current.is_subgroup(nxt): + res.append(nxt) + current = nxt + nxt = nxt.derived_subgroup() + return res + + def derived_subgroup(self): + r"""Compute the derived subgroup. + + Explanation + =========== + + The derived subgroup, or commutator subgroup is the subgroup generated + by all commutators `[g, h] = hgh^{-1}g^{-1}` for `g, h\in G` ; it is + equal to the normal closure of the set of commutators of the generators + ([1], p.28, [11]). + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation([1, 0, 2, 4, 3]) + >>> b = Permutation([0, 1, 3, 2, 4]) + >>> G = PermutationGroup([a, b]) + >>> C = G.derived_subgroup() + >>> list(C.generate(af=True)) + [[0, 1, 2, 3, 4], [0, 1, 3, 4, 2], [0, 1, 4, 2, 3]] + + See Also + ======== + + derived_series + + """ + r = self._r + gens = [p._array_form for p in self.generators] + set_commutators = set() + degree = self._degree + rng = list(range(degree)) + for i in range(r): + for j in range(r): + p1 = gens[i] + p2 = gens[j] + c = list(range(degree)) + for k in rng: + c[p2[p1[k]]] = p1[p2[k]] + ct = tuple(c) + if ct not in set_commutators: + set_commutators.add(ct) + cms = [_af_new(p) for p in set_commutators] + G2 = self.normal_closure(cms) + return G2 + + def generate(self, method="coset", af=False): + """Return iterator to generate the elements of the group. + + Explanation + =========== + + Iteration is done with one of these methods:: + + method='coset' using the Schreier-Sims coset representation + method='dimino' using the Dimino method + + If ``af = True`` it yields the array form of the permutations + + Examples + ======== + + >>> from sympy.combinatorics import PermutationGroup + >>> from sympy.combinatorics.polyhedron import tetrahedron + + The permutation group given in the tetrahedron object is also + true groups: + + >>> G = tetrahedron.pgroup + >>> G.is_group + True + + Also the group generated by the permutations in the tetrahedron + pgroup -- even the first two -- is a proper group: + + >>> H = PermutationGroup(G[0], G[1]) + >>> J = PermutationGroup(list(H.generate())); J + PermutationGroup([ + (0 1)(2 3), + (1 2 3), + (1 3 2), + (0 3 1), + (0 2 3), + (0 3)(1 2), + (0 1 3), + (3)(0 2 1), + (0 3 2), + (3)(0 1 2), + (0 2)(1 3)]) + >>> _.is_group + True + """ + if method == "coset": + return self.generate_schreier_sims(af) + elif method == "dimino": + return self.generate_dimino(af) + else: + raise NotImplementedError('No generation defined for %s' % method) + + def generate_dimino(self, af=False): + """Yield group elements using Dimino's algorithm. + + If ``af == True`` it yields the array form of the permutations. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation([0, 2, 1, 3]) + >>> b = Permutation([0, 2, 3, 1]) + >>> g = PermutationGroup([a, b]) + >>> list(g.generate_dimino(af=True)) + [[0, 1, 2, 3], [0, 2, 1, 3], [0, 2, 3, 1], + [0, 1, 3, 2], [0, 3, 2, 1], [0, 3, 1, 2]] + + References + ========== + + .. [1] The Implementation of Various Algorithms for Permutation Groups in + the Computer Algebra System: AXIOM, N.J. Doye, M.Sc. Thesis + + """ + idn = list(range(self.degree)) + order = 0 + element_list = [idn] + set_element_list = {tuple(idn)} + if af: + yield idn + else: + yield _af_new(idn) + gens = [p._array_form for p in self.generators] + + for i in range(len(gens)): + # D elements of the subgroup G_i generated by gens[:i] + D = element_list.copy() + N = [idn] + while N: + A = N + N = [] + for a in A: + for g in gens[:i + 1]: + ag = _af_rmul(a, g) + if tuple(ag) not in set_element_list: + # produce G_i*g + for d in D: + order += 1 + ap = _af_rmul(d, ag) + if af: + yield ap + else: + p = _af_new(ap) + yield p + element_list.append(ap) + set_element_list.add(tuple(ap)) + N.append(ap) + self._order = len(element_list) + + def generate_schreier_sims(self, af=False): + """Yield group elements using the Schreier-Sims representation + in coset_rank order + + If ``af = True`` it yields the array form of the permutations + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation([0, 2, 1, 3]) + >>> b = Permutation([0, 2, 3, 1]) + >>> g = PermutationGroup([a, b]) + >>> list(g.generate_schreier_sims(af=True)) + [[0, 1, 2, 3], [0, 2, 1, 3], [0, 3, 2, 1], + [0, 1, 3, 2], [0, 2, 3, 1], [0, 3, 1, 2]] + """ + + n = self._degree + u = self.basic_transversals + basic_orbits = self._basic_orbits + if len(u) == 0: + for x in self.generators: + if af: + yield x._array_form + else: + yield x + return + if len(u) == 1: + for i in basic_orbits[0]: + if af: + yield u[0][i]._array_form + else: + yield u[0][i] + return + + u = list(reversed(u)) + basic_orbits = basic_orbits[::-1] + # stg stack of group elements + stg = [list(range(n))] + posmax = [len(x) for x in u] + n1 = len(posmax) - 1 + pos = [0]*n1 + h = 0 + while 1: + # backtrack when finished iterating over coset + if pos[h] >= posmax[h]: + if h == 0: + return + pos[h] = 0 + h -= 1 + stg.pop() + continue + p = _af_rmul(u[h][basic_orbits[h][pos[h]]]._array_form, stg[-1]) + pos[h] += 1 + stg.append(p) + h += 1 + if h == n1: + if af: + for i in basic_orbits[-1]: + p = _af_rmul(u[-1][i]._array_form, stg[-1]) + yield p + else: + for i in basic_orbits[-1]: + p = _af_rmul(u[-1][i]._array_form, stg[-1]) + p1 = _af_new(p) + yield p1 + stg.pop() + h -= 1 + + @property + def generators(self): + """Returns the generators of the group. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation([0, 2, 1]) + >>> b = Permutation([1, 0, 2]) + >>> G = PermutationGroup([a, b]) + >>> G.generators + [(1 2), (2)(0 1)] + + """ + return self._generators + + def contains(self, g, strict=True): + """Test if permutation ``g`` belong to self, ``G``. + + Explanation + =========== + + If ``g`` is an element of ``G`` it can be written as a product + of factors drawn from the cosets of ``G``'s stabilizers. To see + if ``g`` is one of the actual generators defining the group use + ``G.has(g)``. + + If ``strict`` is not ``True``, ``g`` will be resized, if necessary, + to match the size of permutations in ``self``. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + + >>> a = Permutation(1, 2) + >>> b = Permutation(2, 3, 1) + >>> G = PermutationGroup(a, b, degree=5) + >>> G.contains(G[0]) # trivial check + True + >>> elem = Permutation([[2, 3]], size=5) + >>> G.contains(elem) + True + >>> G.contains(Permutation(4)(0, 1, 2, 3)) + False + + If strict is False, a permutation will be resized, if + necessary: + + >>> H = PermutationGroup(Permutation(5)) + >>> H.contains(Permutation(3)) + False + >>> H.contains(Permutation(3), strict=False) + True + + To test if a given permutation is present in the group: + + >>> elem in G.generators + False + >>> G.has(elem) + False + + See Also + ======== + + coset_factor, sympy.core.basic.Basic.has, __contains__ + + """ + if not isinstance(g, Permutation): + return False + if g.size != self.degree: + if strict: + return False + g = Permutation(g, size=self.degree) + if g in self.generators: + return True + return bool(self.coset_factor(g.array_form, True)) + + @property + def is_perfect(self): + """Return ``True`` if the group is perfect. + A group is perfect if it equals to its derived subgroup. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation(1,2,3)(4,5) + >>> b = Permutation(1,2,3,4,5) + >>> G = PermutationGroup([a, b]) + >>> G.is_perfect + False + + """ + if self._is_perfect is None: + self._is_perfect = self.equals(self.derived_subgroup()) + return self._is_perfect + + @property + def is_abelian(self): + """Test if the group is Abelian. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation([0, 2, 1]) + >>> b = Permutation([1, 0, 2]) + >>> G = PermutationGroup([a, b]) + >>> G.is_abelian + False + >>> a = Permutation([0, 2, 1]) + >>> G = PermutationGroup([a]) + >>> G.is_abelian + True + + """ + if self._is_abelian is not None: + return self._is_abelian + + self._is_abelian = True + gens = [p._array_form for p in self.generators] + for x in gens: + for y in gens: + if y <= x: + continue + if not _af_commutes_with(x, y): + self._is_abelian = False + return False + return True + + def abelian_invariants(self): + """ + Returns the abelian invariants for the given group. + Let ``G`` be a nontrivial finite abelian group. Then G is isomorphic to + the direct product of finitely many nontrivial cyclic groups of + prime-power order. + + Explanation + =========== + + The prime-powers that occur as the orders of the factors are uniquely + determined by G. More precisely, the primes that occur in the orders of the + factors in any such decomposition of ``G`` are exactly the primes that divide + ``|G|`` and for any such prime ``p``, if the orders of the factors that are + p-groups in one such decomposition of ``G`` are ``p^{t_1} >= p^{t_2} >= ... p^{t_r}``, + then the orders of the factors that are p-groups in any such decomposition of ``G`` + are ``p^{t_1} >= p^{t_2} >= ... p^{t_r}``. + + The uniquely determined integers ``p^{t_1} >= p^{t_2} >= ... p^{t_r}``, taken + for all primes that divide ``|G|`` are called the invariants of the nontrivial + group ``G`` as suggested in ([14], p. 542). + + Notes + ===== + + We adopt the convention that the invariants of a trivial group are []. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation([0, 2, 1]) + >>> b = Permutation([1, 0, 2]) + >>> G = PermutationGroup([a, b]) + >>> G.abelian_invariants() + [2] + >>> from sympy.combinatorics import CyclicGroup + >>> G = CyclicGroup(7) + >>> G.abelian_invariants() + [7] + + """ + if self.is_trivial: + return [] + gns = self.generators + inv = [] + G = self + H = G.derived_subgroup() + Hgens = H.generators + for p in primefactors(G.order()): + ranks = [] + while True: + pows = [] + for g in gns: + elm = g**p + if not H.contains(elm): + pows.append(elm) + K = PermutationGroup(Hgens + pows) if pows else H + r = G.order()//K.order() + G = K + gns = pows + if r == 1: + break + ranks.append(multiplicity(p, r)) + + if ranks: + pows = [1]*ranks[0] + for i in ranks: + for j in range(i): + pows[j] = pows[j]*p + inv.extend(pows) + inv.sort() + return inv + + def is_elementary(self, p): + """Return ``True`` if the group is elementary abelian. An elementary + abelian group is a finite abelian group, where every nontrivial + element has order `p`, where `p` is a prime. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation([0, 2, 1]) + >>> G = PermutationGroup([a]) + >>> G.is_elementary(2) + True + >>> a = Permutation([0, 2, 1, 3]) + >>> b = Permutation([3, 1, 2, 0]) + >>> G = PermutationGroup([a, b]) + >>> G.is_elementary(2) + True + >>> G.is_elementary(3) + False + + """ + return self.is_abelian and all(g.order() == p for g in self.generators) + + def _eval_is_alt_sym_naive(self, only_sym=False, only_alt=False): + """A naive test using the group order.""" + if only_sym and only_alt: + raise ValueError( + "Both {} and {} cannot be set to True" + .format(only_sym, only_alt)) + + n = self.degree + sym_order = _factorial(n) + order = self.order() + + if order == sym_order: + self._is_sym = True + self._is_alt = False + return not only_alt + + if 2*order == sym_order: + self._is_sym = False + self._is_alt = True + return not only_sym + + return False + + def _eval_is_alt_sym_monte_carlo(self, eps=0.05, perms=None): + """A test using monte-carlo algorithm. + + Parameters + ========== + + eps : float, optional + The criterion for the incorrect ``False`` return. + + perms : list[Permutation], optional + If explicitly given, it tests over the given candidates + for testing. + + If ``None``, it randomly computes ``N_eps`` and chooses + ``N_eps`` sample of the permutation from the group. + + See Also + ======== + + _check_cycles_alt_sym + """ + if perms is None: + n = self.degree + if n < 17: + c_n = 0.34 + else: + c_n = 0.57 + d_n = (c_n*log(2))/log(n) + N_eps = int(-log(eps)/d_n) + + perms = (self.random_pr() for i in range(N_eps)) + return self._eval_is_alt_sym_monte_carlo(perms=perms) + + for perm in perms: + if _check_cycles_alt_sym(perm): + return True + return False + + def is_alt_sym(self, eps=0.05, _random_prec=None): + r"""Monte Carlo test for the symmetric/alternating group for degrees + >= 8. + + Explanation + =========== + + More specifically, it is one-sided Monte Carlo with the + answer True (i.e., G is symmetric/alternating) guaranteed to be + correct, and the answer False being incorrect with probability eps. + + For degree < 8, the order of the group is checked so the test + is deterministic. + + Notes + ===== + + The algorithm itself uses some nontrivial results from group theory and + number theory: + 1) If a transitive group ``G`` of degree ``n`` contains an element + with a cycle of length ``n/2 < p < n-2`` for ``p`` a prime, ``G`` is the + symmetric or alternating group ([1], pp. 81-82) + 2) The proportion of elements in the symmetric/alternating group having + the property described in 1) is approximately `\log(2)/\log(n)` + ([1], p.82; [2], pp. 226-227). + The helper function ``_check_cycles_alt_sym`` is used to + go over the cycles in a permutation and look for ones satisfying 1). + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import DihedralGroup + >>> D = DihedralGroup(10) + >>> D.is_alt_sym() + False + + See Also + ======== + + _check_cycles_alt_sym + + """ + if _random_prec is not None: + N_eps = _random_prec['N_eps'] + perms= (_random_prec[i] for i in range(N_eps)) + return self._eval_is_alt_sym_monte_carlo(perms=perms) + + if self._is_sym or self._is_alt: + return True + if self._is_sym is False and self._is_alt is False: + return False + + n = self.degree + if n < 8: + return self._eval_is_alt_sym_naive() + elif self.is_transitive(): + return self._eval_is_alt_sym_monte_carlo(eps=eps) + + self._is_sym, self._is_alt = False, False + return False + + @property + def is_nilpotent(self): + """Test if the group is nilpotent. + + Explanation + =========== + + A group `G` is nilpotent if it has a central series of finite length. + Alternatively, `G` is nilpotent if its lower central series terminates + with the trivial group. Every nilpotent group is also solvable + ([1], p.29, [12]). + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import (SymmetricGroup, + ... CyclicGroup) + >>> C = CyclicGroup(6) + >>> C.is_nilpotent + True + >>> S = SymmetricGroup(5) + >>> S.is_nilpotent + False + + See Also + ======== + + lower_central_series, is_solvable + + """ + if self._is_nilpotent is None: + lcs = self.lower_central_series() + terminator = lcs[len(lcs) - 1] + gens = terminator.generators + degree = self.degree + identity = _af_new(list(range(degree))) + if all(g == identity for g in gens): + self._is_solvable = True + self._is_nilpotent = True + return True + else: + self._is_nilpotent = False + return False + else: + return self._is_nilpotent + + def is_normal(self, gr, strict=True): + """Test if ``G=self`` is a normal subgroup of ``gr``. + + Explanation + =========== + + G is normal in gr if + for each g2 in G, g1 in gr, ``g = g1*g2*g1**-1`` belongs to G + It is sufficient to check this for each g1 in gr.generators and + g2 in G.generators. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation([1, 2, 0]) + >>> b = Permutation([1, 0, 2]) + >>> G = PermutationGroup([a, b]) + >>> G1 = PermutationGroup([a, Permutation([2, 0, 1])]) + >>> G1.is_normal(G) + True + + """ + if not self.is_subgroup(gr, strict=strict): + return False + d_self = self.degree + d_gr = gr.degree + if self.is_trivial and (d_self == d_gr or not strict): + return True + if self._is_abelian: + return True + new_self = self.copy() + if not strict and d_self != d_gr: + if d_self < d_gr: + new_self = PermGroup(new_self.generators + [Permutation(d_gr - 1)]) + else: + gr = PermGroup(gr.generators + [Permutation(d_self - 1)]) + gens2 = [p._array_form for p in new_self.generators] + gens1 = [p._array_form for p in gr.generators] + for g1 in gens1: + for g2 in gens2: + p = _af_rmuln(g1, g2, _af_invert(g1)) + if not new_self.coset_factor(p, True): + return False + return True + + def is_primitive(self, randomized=True): + r"""Test if a group is primitive. + + Explanation + =========== + + A permutation group ``G`` acting on a set ``S`` is called primitive if + ``S`` contains no nontrivial block under the action of ``G`` + (a block is nontrivial if its cardinality is more than ``1``). + + Notes + ===== + + The algorithm is described in [1], p.83, and uses the function + minimal_block to search for blocks of the form `\{0, k\}` for ``k`` + ranging over representatives for the orbits of `G_0`, the stabilizer of + ``0``. This algorithm has complexity `O(n^2)` where ``n`` is the degree + of the group, and will perform badly if `G_0` is small. + + There are two implementations offered: one finds `G_0` + deterministically using the function ``stabilizer``, and the other + (default) produces random elements of `G_0` using ``random_stab``, + hoping that they generate a subgroup of `G_0` with not too many more + orbits than `G_0` (this is suggested in [1], p.83). Behavior is changed + by the ``randomized`` flag. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import DihedralGroup + >>> D = DihedralGroup(10) + >>> D.is_primitive() + False + + See Also + ======== + + minimal_block, random_stab + + """ + if self._is_primitive is not None: + return self._is_primitive + + if self.is_transitive() is False: + return False + + if randomized: + random_stab_gens = [] + v = self.schreier_vector(0) + for _ in range(len(self)): + random_stab_gens.append(self.random_stab(0, v)) + stab = PermutationGroup(random_stab_gens) + else: + stab = self.stabilizer(0) + orbits = stab.orbits() + for orb in orbits: + x = orb.pop() + if x != 0 and any(e != 0 for e in self.minimal_block([0, x])): + self._is_primitive = False + return False + self._is_primitive = True + return True + + def minimal_blocks(self, randomized=True): + ''' + For a transitive group, return the list of all minimal + block systems. If a group is intransitive, return `False`. + + Examples + ======== + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> from sympy.combinatorics.named_groups import DihedralGroup + >>> DihedralGroup(6).minimal_blocks() + [[0, 1, 0, 1, 0, 1], [0, 1, 2, 0, 1, 2]] + >>> G = PermutationGroup(Permutation(1,2,5)) + >>> G.minimal_blocks() + False + + See Also + ======== + + minimal_block, is_transitive, is_primitive + + ''' + def _number_blocks(blocks): + # number the blocks of a block system + # in order and return the number of + # blocks and the tuple with the + # reordering + n = len(blocks) + appeared = {} + m = 0 + b = [None]*n + for i in range(n): + if blocks[i] not in appeared: + appeared[blocks[i]] = m + b[i] = m + m += 1 + else: + b[i] = appeared[blocks[i]] + return tuple(b), m + + if not self.is_transitive(): + return False + blocks = [] + num_blocks = [] + rep_blocks = [] + if randomized: + random_stab_gens = [] + v = self.schreier_vector(0) + for i in range(len(self)): + random_stab_gens.append(self.random_stab(0, v)) + stab = PermutationGroup(random_stab_gens) + else: + stab = self.stabilizer(0) + orbits = stab.orbits() + for orb in orbits: + x = orb.pop() + if x != 0: + block = self.minimal_block([0, x]) + num_block, _ = _number_blocks(block) + # a representative block (containing 0) + rep = {j for j in range(self.degree) if num_block[j] == 0} + # check if the system is minimal with + # respect to the already discovere ones + minimal = True + blocks_remove_mask = [False] * len(blocks) + for i, r in enumerate(rep_blocks): + if len(r) > len(rep) and rep.issubset(r): + # i-th block system is not minimal + blocks_remove_mask[i] = True + elif len(r) < len(rep) and r.issubset(rep): + # the system being checked is not minimal + minimal = False + break + # remove non-minimal representative blocks + blocks = [b for i, b in enumerate(blocks) if not blocks_remove_mask[i]] + num_blocks = [n for i, n in enumerate(num_blocks) if not blocks_remove_mask[i]] + rep_blocks = [r for i, r in enumerate(rep_blocks) if not blocks_remove_mask[i]] + + if minimal and num_block not in num_blocks: + blocks.append(block) + num_blocks.append(num_block) + rep_blocks.append(rep) + return blocks + + @property + def is_solvable(self): + """Test if the group is solvable. + + ``G`` is solvable if its derived series terminates with the trivial + group ([1], p.29). + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> S = SymmetricGroup(3) + >>> S.is_solvable + True + + See Also + ======== + + is_nilpotent, derived_series + + """ + if self._is_solvable is None: + if self.order() % 2 != 0: + return True + ds = self.derived_series() + terminator = ds[len(ds) - 1] + gens = terminator.generators + degree = self.degree + identity = _af_new(list(range(degree))) + if all(g == identity for g in gens): + self._is_solvable = True + return True + else: + self._is_solvable = False + return False + else: + return self._is_solvable + + def is_subgroup(self, G, strict=True): + """Return ``True`` if all elements of ``self`` belong to ``G``. + + If ``strict`` is ``False`` then if ``self``'s degree is smaller + than ``G``'s, the elements will be resized to have the same degree. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> from sympy.combinatorics import SymmetricGroup, CyclicGroup + + Testing is strict by default: the degree of each group must be the + same: + + >>> p = Permutation(0, 1, 2, 3, 4, 5) + >>> G1 = PermutationGroup([Permutation(0, 1, 2), Permutation(0, 1)]) + >>> G2 = PermutationGroup([Permutation(0, 2), Permutation(0, 1, 2)]) + >>> G3 = PermutationGroup([p, p**2]) + >>> assert G1.order() == G2.order() == G3.order() == 6 + >>> G1.is_subgroup(G2) + True + >>> G1.is_subgroup(G3) + False + >>> G3.is_subgroup(PermutationGroup(G3[1])) + False + >>> G3.is_subgroup(PermutationGroup(G3[0])) + True + + To ignore the size, set ``strict`` to ``False``: + + >>> S3 = SymmetricGroup(3) + >>> S5 = SymmetricGroup(5) + >>> S3.is_subgroup(S5, strict=False) + True + >>> C7 = CyclicGroup(7) + >>> G = S5*C7 + >>> S5.is_subgroup(G, False) + True + >>> C7.is_subgroup(G, 0) + False + + """ + if isinstance(G, SymmetricPermutationGroup): + if self.degree != G.degree: + return False + return True + if not isinstance(G, PermutationGroup): + return False + if self == G or self.generators[0]==Permutation(): + return True + if G.order() % self.order() != 0: + return False + if self.degree == G.degree or \ + (self.degree < G.degree and not strict): + gens = self.generators + else: + return False + return all(G.contains(g, strict=strict) for g in gens) + + @property + def is_polycyclic(self): + """Return ``True`` if a group is polycyclic. A group is polycyclic if + it has a subnormal series with cyclic factors. For finite groups, + this is the same as if the group is solvable. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation([0, 2, 1, 3]) + >>> b = Permutation([2, 0, 1, 3]) + >>> G = PermutationGroup([a, b]) + >>> G.is_polycyclic + True + + """ + return self.is_solvable + + def is_transitive(self, strict=True): + """Test if the group is transitive. + + Explanation + =========== + + A group is transitive if it has a single orbit. + + If ``strict`` is ``False`` the group is transitive if it has + a single orbit of length different from 1. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation([0, 2, 1, 3]) + >>> b = Permutation([2, 0, 1, 3]) + >>> G1 = PermutationGroup([a, b]) + >>> G1.is_transitive() + False + >>> G1.is_transitive(strict=False) + True + >>> c = Permutation([2, 3, 0, 1]) + >>> G2 = PermutationGroup([a, c]) + >>> G2.is_transitive() + True + >>> d = Permutation([1, 0, 2, 3]) + >>> e = Permutation([0, 1, 3, 2]) + >>> G3 = PermutationGroup([d, e]) + >>> G3.is_transitive() or G3.is_transitive(strict=False) + False + + """ + if self._is_transitive: # strict or not, if True then True + return self._is_transitive + if strict: + if self._is_transitive is not None: # we only store strict=True + return self._is_transitive + + ans = len(self.orbit(0)) == self.degree + self._is_transitive = ans + return ans + + got_orb = False + for x in self.orbits(): + if len(x) > 1: + if got_orb: + return False + got_orb = True + return got_orb + + @property + def is_trivial(self): + """Test if the group is the trivial group. + + This is true if the group contains only the identity permutation. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> G = PermutationGroup([Permutation([0, 1, 2])]) + >>> G.is_trivial + True + + """ + if self._is_trivial is None: + self._is_trivial = len(self) == 1 and self[0].is_Identity + return self._is_trivial + + def lower_central_series(self): + r"""Return the lower central series for the group. + + The lower central series for a group `G` is the series + `G = G_0 > G_1 > G_2 > \ldots` where + `G_k = [G, G_{k-1}]`, i.e. every term after the first is equal to the + commutator of `G` and the previous term in `G1` ([1], p.29). + + Returns + ======= + + A list of permutation groups in the order `G = G_0, G_1, G_2, \ldots` + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import (AlternatingGroup, + ... DihedralGroup) + >>> A = AlternatingGroup(4) + >>> len(A.lower_central_series()) + 2 + >>> A.lower_central_series()[1].is_subgroup(DihedralGroup(2)) + True + + See Also + ======== + + commutator, derived_series + + """ + res = [self] + current = self + nxt = self.commutator(self, current) + while not current.is_subgroup(nxt): + res.append(nxt) + current = nxt + nxt = self.commutator(self, current) + return res + + @property + def max_div(self): + """Maximum proper divisor of the degree of a permutation group. + + Explanation + =========== + + Obviously, this is the degree divided by its minimal proper divisor + (larger than ``1``, if one exists). As it is guaranteed to be prime, + the ``sieve`` from ``sympy.ntheory`` is used. + This function is also used as an optimization tool for the functions + ``minimal_block`` and ``_union_find_merge``. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> G = PermutationGroup([Permutation([0, 2, 1, 3])]) + >>> G.max_div + 2 + + See Also + ======== + + minimal_block, _union_find_merge + + """ + if self._max_div is not None: + return self._max_div + n = self.degree + if n == 1: + return 1 + for x in sieve: + if n % x == 0: + d = n//x + self._max_div = d + return d + + def minimal_block(self, points): + r"""For a transitive group, finds the block system generated by + ``points``. + + Explanation + =========== + + If a group ``G`` acts on a set ``S``, a nonempty subset ``B`` of ``S`` + is called a block under the action of ``G`` if for all ``g`` in ``G`` + we have ``gB = B`` (``g`` fixes ``B``) or ``gB`` and ``B`` have no + common points (``g`` moves ``B`` entirely). ([1], p.23; [6]). + + The distinct translates ``gB`` of a block ``B`` for ``g`` in ``G`` + partition the set ``S`` and this set of translates is known as a block + system. Moreover, we obviously have that all blocks in the partition + have the same size, hence the block size divides ``|S|`` ([1], p.23). + A ``G``-congruence is an equivalence relation ``~`` on the set ``S`` + such that ``a ~ b`` implies ``g(a) ~ g(b)`` for all ``g`` in ``G``. + For a transitive group, the equivalence classes of a ``G``-congruence + and the blocks of a block system are the same thing ([1], p.23). + + The algorithm below checks the group for transitivity, and then finds + the ``G``-congruence generated by the pairs ``(p_0, p_1), (p_0, p_2), + ..., (p_0,p_{k-1})`` which is the same as finding the maximal block + system (i.e., the one with minimum block size) such that + ``p_0, ..., p_{k-1}`` are in the same block ([1], p.83). + + It is an implementation of Atkinson's algorithm, as suggested in [1], + and manipulates an equivalence relation on the set ``S`` using a + union-find data structure. The running time is just above + `O(|points||S|)`. ([1], pp. 83-87; [7]). + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import DihedralGroup + >>> D = DihedralGroup(10) + >>> D.minimal_block([0, 5]) + [0, 1, 2, 3, 4, 0, 1, 2, 3, 4] + >>> D.minimal_block([0, 1]) + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + + See Also + ======== + + _union_find_rep, _union_find_merge, is_transitive, is_primitive + + """ + if not self.is_transitive(): + return False + n = self.degree + gens = self.generators + # initialize the list of equivalence class representatives + parents = list(range(n)) + ranks = [1]*n + not_rep = [] + k = len(points) + # the block size must divide the degree of the group + if k > self.max_div: + return [0]*n + for i in range(k - 1): + parents[points[i + 1]] = points[0] + not_rep.append(points[i + 1]) + ranks[points[0]] = k + i = 0 + len_not_rep = k - 1 + while i < len_not_rep: + gamma = not_rep[i] + i += 1 + for gen in gens: + # find has side effects: performs path compression on the list + # of representatives + delta = self._union_find_rep(gamma, parents) + # union has side effects: performs union by rank on the list + # of representatives + temp = self._union_find_merge(gen(gamma), gen(delta), ranks, + parents, not_rep) + if temp == -1: + return [0]*n + len_not_rep += temp + for i in range(n): + # force path compression to get the final state of the equivalence + # relation + self._union_find_rep(i, parents) + + # rewrite result so that block representatives are minimal + new_reps = {} + return [new_reps.setdefault(r, i) for i, r in enumerate(parents)] + + def conjugacy_class(self, x): + r"""Return the conjugacy class of an element in the group. + + Explanation + =========== + + The conjugacy class of an element ``g`` in a group ``G`` is the set of + elements ``x`` in ``G`` that are conjugate with ``g``, i.e. for which + + ``g = xax^{-1}`` + + for some ``a`` in ``G``. + + Note that conjugacy is an equivalence relation, and therefore that + conjugacy classes are partitions of ``G``. For a list of all the + conjugacy classes of the group, use the conjugacy_classes() method. + + In a permutation group, each conjugacy class corresponds to a particular + `cycle structure': for example, in ``S_3``, the conjugacy classes are: + + * the identity class, ``{()}`` + * all transpositions, ``{(1 2), (1 3), (2 3)}`` + * all 3-cycles, ``{(1 2 3), (1 3 2)}`` + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, SymmetricGroup + >>> S3 = SymmetricGroup(3) + >>> S3.conjugacy_class(Permutation(0, 1, 2)) + {(0 1 2), (0 2 1)} + + Notes + ===== + + This procedure computes the conjugacy class directly by finding the + orbit of the element under conjugation in G. This algorithm is only + feasible for permutation groups of relatively small order, but is like + the orbit() function itself in that respect. + """ + # Ref: "Computing the conjugacy classes of finite groups"; Butler, G. + # Groups '93 Galway/St Andrews; edited by Campbell, C. M. + new_class = {x} + last_iteration = new_class + + while len(last_iteration) > 0: + this_iteration = set() + + for y in last_iteration: + for s in self.generators: + conjugated = s * y * (~s) + if conjugated not in new_class: + this_iteration.add(conjugated) + + new_class.update(last_iteration) + last_iteration = this_iteration + + return new_class + + + def conjugacy_classes(self): + r"""Return the conjugacy classes of the group. + + Explanation + =========== + + As described in the documentation for the .conjugacy_class() function, + conjugacy is an equivalence relation on a group G which partitions the + set of elements. This method returns a list of all these conjugacy + classes of G. + + Examples + ======== + + >>> from sympy.combinatorics import SymmetricGroup + >>> SymmetricGroup(3).conjugacy_classes() + [{(2)}, {(0 1 2), (0 2 1)}, {(0 2), (1 2), (2)(0 1)}] + + """ + identity = _af_new(list(range(self.degree))) + known_elements = {identity} + classes = [known_elements.copy()] + + for x in self.generate(): + if x not in known_elements: + new_class = self.conjugacy_class(x) + classes.append(new_class) + known_elements.update(new_class) + + return classes + + def normal_closure(self, other, k=10): + r"""Return the normal closure of a subgroup/set of permutations. + + Explanation + =========== + + If ``S`` is a subset of a group ``G``, the normal closure of ``A`` in ``G`` + is defined as the intersection of all normal subgroups of ``G`` that + contain ``A`` ([1], p.14). Alternatively, it is the group generated by + the conjugates ``x^{-1}yx`` for ``x`` a generator of ``G`` and ``y`` a + generator of the subgroup ``\left\langle S\right\rangle`` generated by + ``S`` (for some chosen generating set for ``\left\langle S\right\rangle``) + ([1], p.73). + + Parameters + ========== + + other + a subgroup/list of permutations/single permutation + k + an implementation-specific parameter that determines the number + of conjugates that are adjoined to ``other`` at once + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import (SymmetricGroup, + ... CyclicGroup, AlternatingGroup) + >>> S = SymmetricGroup(5) + >>> C = CyclicGroup(5) + >>> G = S.normal_closure(C) + >>> G.order() + 60 + >>> G.is_subgroup(AlternatingGroup(5)) + True + + See Also + ======== + + commutator, derived_subgroup, random_pr + + Notes + ===== + + The algorithm is described in [1], pp. 73-74; it makes use of the + generation of random elements for permutation groups by the product + replacement algorithm. + + """ + if hasattr(other, 'generators'): + degree = self.degree + identity = _af_new(list(range(degree))) + + if all(g == identity for g in other.generators): + return other + Z = PermutationGroup(other.generators[:]) + base, strong_gens = Z.schreier_sims_incremental() + strong_gens_distr = _distribute_gens_by_base(base, strong_gens) + basic_orbits, basic_transversals = \ + _orbits_transversals_from_bsgs(base, strong_gens_distr) + + self._random_pr_init(r=10, n=20) + + _loop = True + while _loop: + Z._random_pr_init(r=10, n=10) + for _ in range(k): + g = self.random_pr() + h = Z.random_pr() + conj = h^g + res = _strip(conj, base, basic_orbits, basic_transversals) + if res[0] != identity or res[1] != len(base) + 1: + gens = Z.generators + gens.append(conj) + Z = PermutationGroup(gens) + strong_gens.append(conj) + temp_base, temp_strong_gens = \ + Z.schreier_sims_incremental(base, strong_gens) + base, strong_gens = temp_base, temp_strong_gens + strong_gens_distr = \ + _distribute_gens_by_base(base, strong_gens) + basic_orbits, basic_transversals = \ + _orbits_transversals_from_bsgs(base, + strong_gens_distr) + _loop = False + for g in self.generators: + for h in Z.generators: + conj = h^g + res = _strip(conj, base, basic_orbits, + basic_transversals) + if res[0] != identity or res[1] != len(base) + 1: + _loop = True + break + if _loop: + break + return Z + elif hasattr(other, '__getitem__'): + return self.normal_closure(PermutationGroup(other)) + elif hasattr(other, 'array_form'): + return self.normal_closure(PermutationGroup([other])) + + def orbit(self, alpha, action='tuples'): + r"""Compute the orbit of alpha `\{g(\alpha) | g \in G\}` as a set. + + Explanation + =========== + + The time complexity of the algorithm used here is `O(|Orb|*r)` where + `|Orb|` is the size of the orbit and ``r`` is the number of generators of + the group. For a more detailed analysis, see [1], p.78, [2], pp. 19-21. + Here alpha can be a single point, or a list of points. + + If alpha is a single point, the ordinary orbit is computed. + if alpha is a list of points, there are three available options: + + 'union' - computes the union of the orbits of the points in the list + 'tuples' - computes the orbit of the list interpreted as an ordered + tuple under the group action ( i.e., g((1,2,3)) = (g(1), g(2), g(3)) ) + 'sets' - computes the orbit of the list interpreted as a sets + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation([1, 2, 0, 4, 5, 6, 3]) + >>> G = PermutationGroup([a]) + >>> G.orbit(0) + {0, 1, 2} + >>> G.orbit([0, 4], 'union') + {0, 1, 2, 3, 4, 5, 6} + + See Also + ======== + + orbit_transversal + + """ + return _orbit(self.degree, self.generators, alpha, action) + + def orbit_rep(self, alpha, beta, schreier_vector=None): + """Return a group element which sends ``alpha`` to ``beta``. + + Explanation + =========== + + If ``beta`` is not in the orbit of ``alpha``, the function returns + ``False``. This implementation makes use of the schreier vector. + For a proof of correctness, see [1], p.80 + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import AlternatingGroup + >>> G = AlternatingGroup(5) + >>> G.orbit_rep(0, 4) + (0 4 1 2 3) + + See Also + ======== + + schreier_vector + + """ + if schreier_vector is None: + schreier_vector = self.schreier_vector(alpha) + if schreier_vector[beta] is None: + return False + k = schreier_vector[beta] + gens = [x._array_form for x in self.generators] + a = [] + while k != -1: + a.append(gens[k]) + beta = gens[k].index(beta) # beta = (~gens[k])(beta) + k = schreier_vector[beta] + if a: + return _af_new(_af_rmuln(*a)) + else: + return _af_new(list(range(self._degree))) + + def orbit_transversal(self, alpha, pairs=False): + r"""Computes a transversal for the orbit of ``alpha`` as a set. + + Explanation + =========== + + For a permutation group `G`, a transversal for the orbit + `Orb = \{g(\alpha) | g \in G\}` is a set + `\{g_\beta | g_\beta(\alpha) = \beta\}` for `\beta \in Orb`. + Note that there may be more than one possible transversal. + If ``pairs`` is set to ``True``, it returns the list of pairs + `(\beta, g_\beta)`. For a proof of correctness, see [1], p.79 + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import DihedralGroup + >>> G = DihedralGroup(6) + >>> G.orbit_transversal(0) + [(5), (0 1 2 3 4 5), (0 5)(1 4)(2 3), (0 2 4)(1 3 5), (5)(0 4)(1 3), (0 3)(1 4)(2 5)] + + See Also + ======== + + orbit + + """ + return _orbit_transversal(self._degree, self.generators, alpha, pairs) + + def orbits(self, rep=False): + """Return the orbits of ``self``, ordered according to lowest element + in each orbit. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation(1, 5)(2, 3)(4, 0, 6) + >>> b = Permutation(1, 5)(3, 4)(2, 6, 0) + >>> G = PermutationGroup([a, b]) + >>> G.orbits() + [{0, 2, 3, 4, 6}, {1, 5}] + """ + return _orbits(self._degree, self._generators) + + def order(self): + """Return the order of the group: the number of permutations that + can be generated from elements of the group. + + The number of permutations comprising the group is given by + ``len(group)``; the length of each permutation in the group is + given by ``group.size``. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + + >>> a = Permutation([1, 0, 2]) + >>> G = PermutationGroup([a]) + >>> G.degree + 3 + >>> len(G) + 1 + >>> G.order() + 2 + >>> list(G.generate()) + [(2), (2)(0 1)] + + >>> a = Permutation([0, 2, 1]) + >>> b = Permutation([1, 0, 2]) + >>> G = PermutationGroup([a, b]) + >>> G.order() + 6 + + See Also + ======== + + degree + + """ + if self._order is not None: + return self._order + if self._is_sym: + n = self._degree + self._order = factorial(n) + return self._order + if self._is_alt: + n = self._degree + self._order = factorial(n)/2 + return self._order + + m = prod([len(x) for x in self.basic_transversals]) + self._order = m + return m + + def index(self, H): + """ + Returns the index of a permutation group. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation(1,2,3) + >>> b =Permutation(3) + >>> G = PermutationGroup([a]) + >>> H = PermutationGroup([b]) + >>> G.index(H) + 3 + + """ + if H.is_subgroup(self): + return self.order()//H.order() + + @property + def is_symmetric(self): + """Return ``True`` if the group is symmetric. + + Examples + ======== + + >>> from sympy.combinatorics import SymmetricGroup + >>> g = SymmetricGroup(5) + >>> g.is_symmetric + True + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> g = PermutationGroup( + ... Permutation(0, 1, 2, 3, 4), + ... Permutation(2, 3)) + >>> g.is_symmetric + True + + Notes + ===== + + This uses a naive test involving the computation of the full + group order. + If you need more quicker taxonomy for large groups, you can use + :meth:`PermutationGroup.is_alt_sym`. + However, :meth:`PermutationGroup.is_alt_sym` may not be accurate + and is not able to distinguish between an alternating group and + a symmetric group. + + See Also + ======== + + is_alt_sym + """ + _is_sym = self._is_sym + if _is_sym is not None: + return _is_sym + + n = self.degree + if n >= 8: + if self.is_transitive(): + _is_alt_sym = self._eval_is_alt_sym_monte_carlo() + if _is_alt_sym: + if any(g.is_odd for g in self.generators): + self._is_sym, self._is_alt = True, False + return True + + self._is_sym, self._is_alt = False, True + return False + + return self._eval_is_alt_sym_naive(only_sym=True) + + self._is_sym, self._is_alt = False, False + return False + + return self._eval_is_alt_sym_naive(only_sym=True) + + + @property + def is_alternating(self): + """Return ``True`` if the group is alternating. + + Examples + ======== + + >>> from sympy.combinatorics import AlternatingGroup + >>> g = AlternatingGroup(5) + >>> g.is_alternating + True + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> g = PermutationGroup( + ... Permutation(0, 1, 2, 3, 4), + ... Permutation(2, 3, 4)) + >>> g.is_alternating + True + + Notes + ===== + + This uses a naive test involving the computation of the full + group order. + If you need more quicker taxonomy for large groups, you can use + :meth:`PermutationGroup.is_alt_sym`. + However, :meth:`PermutationGroup.is_alt_sym` may not be accurate + and is not able to distinguish between an alternating group and + a symmetric group. + + See Also + ======== + + is_alt_sym + """ + _is_alt = self._is_alt + if _is_alt is not None: + return _is_alt + + n = self.degree + if n >= 8: + if self.is_transitive(): + _is_alt_sym = self._eval_is_alt_sym_monte_carlo() + if _is_alt_sym: + if all(g.is_even for g in self.generators): + self._is_sym, self._is_alt = False, True + return True + + self._is_sym, self._is_alt = True, False + return False + + return self._eval_is_alt_sym_naive(only_alt=True) + + self._is_sym, self._is_alt = False, False + return False + + return self._eval_is_alt_sym_naive(only_alt=True) + + @classmethod + def _distinct_primes_lemma(cls, primes): + """Subroutine to test if there is only one cyclic group for the + order.""" + primes = sorted(primes) + l = len(primes) + for i in range(l): + for j in range(i+1, l): + if primes[j] % primes[i] == 1: + return None + return True + + @property + def is_cyclic(self): + r""" + Return ``True`` if the group is Cyclic. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import AbelianGroup + >>> G = AbelianGroup(3, 4) + >>> G.is_cyclic + True + >>> G = AbelianGroup(4, 4) + >>> G.is_cyclic + False + + Notes + ===== + + If the order of a group $n$ can be factored into the distinct + primes $p_1, p_2, \dots , p_s$ and if + + .. math:: + \forall i, j \in \{1, 2, \dots, s \}: + p_i \not \equiv 1 \pmod {p_j} + + holds true, there is only one group of the order $n$ which + is a cyclic group [1]_. This is a generalization of the lemma + that the group of order $15, 35, \dots$ are cyclic. + + And also, these additional lemmas can be used to test if a + group is cyclic if the order of the group is already found. + + - If the group is abelian and the order of the group is + square-free, the group is cyclic. + - If the order of the group is less than $6$ and is not $4$, the + group is cyclic. + - If the order of the group is prime, the group is cyclic. + + References + ========== + + .. [1] 1978: John S. Rose: A Course on Group Theory, + Introduction to Finite Group Theory: 1.4 + """ + if self._is_cyclic is not None: + return self._is_cyclic + + if len(self.generators) == 1: + self._is_cyclic = True + self._is_abelian = True + return True + + if self._is_abelian is False: + self._is_cyclic = False + return False + + order = self.order() + + if order < 6: + self._is_abelian = True + if order != 4: + self._is_cyclic = True + return True + + factors = factorint(order) + if all(v == 1 for v in factors.values()): + if self._is_abelian: + self._is_cyclic = True + return True + + primes = list(factors.keys()) + if PermutationGroup._distinct_primes_lemma(primes) is True: + self._is_cyclic = True + self._is_abelian = True + return True + + if not self.is_abelian: + self._is_cyclic = False + return False + + self._is_cyclic = all( + any(g**(order//p) != self.identity for g in self.generators) + for p, e in factors.items() if e > 1 + ) + return self._is_cyclic + + @property + def is_dihedral(self): + r""" + Return ``True`` if the group is dihedral. + + Examples + ======== + + >>> from sympy.combinatorics.perm_groups import PermutationGroup + >>> from sympy.combinatorics.permutations import Permutation + >>> from sympy.combinatorics.named_groups import SymmetricGroup, CyclicGroup + >>> G = PermutationGroup(Permutation(1, 6)(2, 5)(3, 4), Permutation(0, 1, 2, 3, 4, 5, 6)) + >>> G.is_dihedral + True + >>> G = SymmetricGroup(3) + >>> G.is_dihedral + True + >>> G = CyclicGroup(6) + >>> G.is_dihedral + False + + References + ========== + + .. [Di1] https://math.stackexchange.com/questions/827230/given-a-cayley-table-is-there-an-algorithm-to-determine-if-it-is-a-dihedral-gro/827273#827273 + .. [Di2] https://kconrad.math.uconn.edu/blurbs/grouptheory/dihedral.pdf + .. [Di3] https://kconrad.math.uconn.edu/blurbs/grouptheory/dihedral2.pdf + .. [Di4] https://en.wikipedia.org/wiki/Dihedral_group + """ + if self._is_dihedral is not None: + return self._is_dihedral + + order = self.order() + + if order % 2 == 1: + self._is_dihedral = False + return False + if order == 2: + self._is_dihedral = True + return True + if order == 4: + # The dihedral group of order 4 is the Klein 4-group. + self._is_dihedral = not self.is_cyclic + return self._is_dihedral + if self.is_abelian: + # The only abelian dihedral groups are the ones of orders 2 and 4. + self._is_dihedral = False + return False + + # Now we know the group is of even order >= 6, and nonabelian. + n = order // 2 + + # Handle special cases where there are exactly two generators. + gens = self.generators + if len(gens) == 2: + x, y = gens + a, b = x.order(), y.order() + # Make a >= b + if a < b: + x, y, a, b = y, x, b, a + # Using Theorem 2.1 of [Di3]: + if a == 2 == b: + self._is_dihedral = True + return True + # Using Theorem 1.1 of [Di3]: + if a == n and b == 2 and y*x*y == ~x: + self._is_dihedral = True + return True + + # Proceed with algorithm of [Di1] + # Find elements of orders 2 and n + order_2, order_n = [], [] + for p in self.elements: + k = p.order() + if k == 2: + order_2.append(p) + elif k == n: + order_n.append(p) + + if len(order_2) != n + 1 - (n % 2): + self._is_dihedral = False + return False + + if not order_n: + self._is_dihedral = False + return False + + x = order_n[0] + # Want an element y of order 2 that is not a power of x + # (i.e. that is not the 180-deg rotation, when n is even). + y = order_2[0] + if n % 2 == 0 and y == x**(n//2): + y = order_2[1] + + self._is_dihedral = (y*x*y == ~x) + return self._is_dihedral + + def pointwise_stabilizer(self, points, incremental=True): + r"""Return the pointwise stabilizer for a set of points. + + Explanation + =========== + + For a permutation group `G` and a set of points + `\{p_1, p_2,\ldots, p_k\}`, the pointwise stabilizer of + `p_1, p_2, \ldots, p_k` is defined as + `G_{p_1,\ldots, p_k} = + \{g\in G | g(p_i) = p_i \forall i\in\{1, 2,\ldots,k\}\}` ([1],p20). + It is a subgroup of `G`. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> S = SymmetricGroup(7) + >>> Stab = S.pointwise_stabilizer([2, 3, 5]) + >>> Stab.is_subgroup(S.stabilizer(2).stabilizer(3).stabilizer(5)) + True + + See Also + ======== + + stabilizer, schreier_sims_incremental + + Notes + ===== + + When incremental == True, + rather than the obvious implementation using successive calls to + ``.stabilizer()``, this uses the incremental Schreier-Sims algorithm + to obtain a base with starting segment - the given points. + + """ + if incremental: + base, strong_gens = self.schreier_sims_incremental(base=points) + stab_gens = [] + degree = self.degree + for gen in strong_gens: + if [gen(point) for point in points] == points: + stab_gens.append(gen) + if not stab_gens: + stab_gens = _af_new(list(range(degree))) + return PermutationGroup(stab_gens) + else: + gens = self._generators + degree = self.degree + for x in points: + gens = _stabilizer(degree, gens, x) + return PermutationGroup(gens) + + def make_perm(self, n, seed=None): + """ + Multiply ``n`` randomly selected permutations from + pgroup together, starting with the identity + permutation. If ``n`` is a list of integers, those + integers will be used to select the permutations and they + will be applied in L to R order: make_perm((A, B, C)) will + give CBA(I) where I is the identity permutation. + + ``seed`` is used to set the seed for the random selection + of permutations from pgroup. If this is a list of integers, + the corresponding permutations from pgroup will be selected + in the order give. This is mainly used for testing purposes. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a, b = [Permutation([1, 0, 3, 2]), Permutation([1, 3, 0, 2])] + >>> G = PermutationGroup([a, b]) + >>> G.make_perm(1, [0]) + (0 1)(2 3) + >>> G.make_perm(3, [0, 1, 0]) + (0 2 3 1) + >>> G.make_perm([0, 1, 0]) + (0 2 3 1) + + See Also + ======== + + random + """ + if is_sequence(n): + if seed is not None: + raise ValueError('If n is a sequence, seed should be None') + n, seed = len(n), n + else: + try: + n = int(n) + except TypeError: + raise ValueError('n must be an integer or a sequence.') + randomrange = _randrange(seed) + + # start with the identity permutation + result = Permutation(list(range(self.degree))) + m = len(self) + for _ in range(n): + p = self[randomrange(m)] + result = rmul(result, p) + return result + + def random(self, af=False): + """Return a random group element + """ + rank = randrange(self.order()) + return self.coset_unrank(rank, af) + + def random_pr(self, gen_count=11, iterations=50, _random_prec=None): + """Return a random group element using product replacement. + + Explanation + =========== + + For the details of the product replacement algorithm, see + ``_random_pr_init`` In ``random_pr`` the actual 'product replacement' + is performed. Notice that if the attribute ``_random_gens`` + is empty, it needs to be initialized by ``_random_pr_init``. + + See Also + ======== + + _random_pr_init + + """ + if self._random_gens == []: + self._random_pr_init(gen_count, iterations) + random_gens = self._random_gens + r = len(random_gens) - 1 + + # handle randomized input for testing purposes + if _random_prec is None: + s = randrange(r) + t = randrange(r - 1) + if t == s: + t = r - 1 + x = choice([1, 2]) + e = choice([-1, 1]) + else: + s = _random_prec['s'] + t = _random_prec['t'] + if t == s: + t = r - 1 + x = _random_prec['x'] + e = _random_prec['e'] + + if x == 1: + random_gens[s] = _af_rmul(random_gens[s], _af_pow(random_gens[t], e)) + random_gens[r] = _af_rmul(random_gens[r], random_gens[s]) + else: + random_gens[s] = _af_rmul(_af_pow(random_gens[t], e), random_gens[s]) + random_gens[r] = _af_rmul(random_gens[s], random_gens[r]) + return _af_new(random_gens[r]) + + def random_stab(self, alpha, schreier_vector=None, _random_prec=None): + """Random element from the stabilizer of ``alpha``. + + The schreier vector for ``alpha`` is an optional argument used + for speeding up repeated calls. The algorithm is described in [1], p.81 + + See Also + ======== + + random_pr, orbit_rep + + """ + if schreier_vector is None: + schreier_vector = self.schreier_vector(alpha) + if _random_prec is None: + rand = self.random_pr() + else: + rand = _random_prec['rand'] + beta = rand(alpha) + h = self.orbit_rep(alpha, beta, schreier_vector) + return rmul(~h, rand) + + def schreier_sims(self): + """Schreier-Sims algorithm. + + Explanation + =========== + + It computes the generators of the chain of stabilizers + `G > G_{b_1} > .. > G_{b1,..,b_r} > 1` + in which `G_{b_1,..,b_i}` stabilizes `b_1,..,b_i`, + and the corresponding ``s`` cosets. + An element of the group can be written as the product + `h_1*..*h_s`. + + We use the incremental Schreier-Sims algorithm. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation([0, 2, 1]) + >>> b = Permutation([1, 0, 2]) + >>> G = PermutationGroup([a, b]) + >>> G.schreier_sims() + >>> G.basic_transversals + [{0: (2)(0 1), 1: (2), 2: (1 2)}, + {0: (2), 2: (0 2)}] + """ + if self._transversals: + return + self._schreier_sims() + return + + def _schreier_sims(self, base=None): + schreier = self.schreier_sims_incremental(base=base, slp_dict=True) + base, strong_gens = schreier[:2] + self._base = base + self._strong_gens = strong_gens + self._strong_gens_slp = schreier[2] + if not base: + self._transversals = [] + self._basic_orbits = [] + return + + strong_gens_distr = _distribute_gens_by_base(base, strong_gens) + basic_orbits, transversals, slps = _orbits_transversals_from_bsgs(base,\ + strong_gens_distr, slp=True) + + # rewrite the indices stored in slps in terms of strong_gens + for i, slp in enumerate(slps): + gens = strong_gens_distr[i] + for k in slp: + slp[k] = [strong_gens.index(gens[s]) for s in slp[k]] + + self._transversals = transversals + self._basic_orbits = [sorted(x) for x in basic_orbits] + self._transversal_slp = slps + + def schreier_sims_incremental(self, base=None, gens=None, slp_dict=False): + """Extend a sequence of points and generating set to a base and strong + generating set. + + Parameters + ========== + + base + The sequence of points to be extended to a base. Optional + parameter with default value ``[]``. + gens + The generating set to be extended to a strong generating set + relative to the base obtained. Optional parameter with default + value ``self.generators``. + + slp_dict + If `True`, return a dictionary `{g: gens}` for each strong + generator `g` where `gens` is a list of strong generators + coming before `g` in `strong_gens`, such that the product + of the elements of `gens` is equal to `g`. + + Returns + ======= + + (base, strong_gens) + ``base`` is the base obtained, and ``strong_gens`` is the strong + generating set relative to it. The original parameters ``base``, + ``gens`` remain unchanged. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import AlternatingGroup + >>> from sympy.combinatorics.testutil import _verify_bsgs + >>> A = AlternatingGroup(7) + >>> base = [2, 3] + >>> seq = [2, 3] + >>> base, strong_gens = A.schreier_sims_incremental(base=seq) + >>> _verify_bsgs(A, base, strong_gens) + True + >>> base[:2] + [2, 3] + + Notes + ===== + + This version of the Schreier-Sims algorithm runs in polynomial time. + There are certain assumptions in the implementation - if the trivial + group is provided, ``base`` and ``gens`` are returned immediately, + as any sequence of points is a base for the trivial group. If the + identity is present in the generators ``gens``, it is removed as + it is a redundant generator. + The implementation is described in [1], pp. 90-93. + + See Also + ======== + + schreier_sims, schreier_sims_random + + """ + if base is None: + base = [] + if gens is None: + gens = self.generators[:] + degree = self.degree + id_af = list(range(degree)) + # handle the trivial group + if len(gens) == 1 and gens[0].is_Identity: + if slp_dict: + return base, gens, {gens[0]: [gens[0]]} + return base, gens + # prevent side effects + _base, _gens = base[:], gens[:] + # remove the identity as a generator + _gens = [x for x in _gens if not x.is_Identity] + # make sure no generator fixes all base points + for gen in _gens: + if all(x == gen._array_form[x] for x in _base): + for new in id_af: + if gen._array_form[new] != new: + break + else: + assert None # can this ever happen? + _base.append(new) + # distribute generators according to basic stabilizers + strong_gens_distr = _distribute_gens_by_base(_base, _gens) + strong_gens_slp = [] + # initialize the basic stabilizers, basic orbits and basic transversals + orbs = {} + transversals = {} + slps = {} + base_len = len(_base) + for i in range(base_len): + transversals[i], slps[i] = _orbit_transversal(degree, strong_gens_distr[i], + _base[i], pairs=True, af=True, slp=True) + transversals[i] = dict(transversals[i]) + orbs[i] = list(transversals[i].keys()) + # main loop: amend the stabilizer chain until we have generators + # for all stabilizers + i = base_len - 1 + while i >= 0: + # this flag is used to continue with the main loop from inside + # a nested loop + continue_i = False + # test the generators for being a strong generating set + db = {} + for beta, u_beta in list(transversals[i].items()): + for j, gen in enumerate(strong_gens_distr[i]): + gb = gen._array_form[beta] + u1 = transversals[i][gb] + g1 = _af_rmul(gen._array_form, u_beta) + slp = [(i, g) for g in slps[i][beta]] + slp = [(i, j)] + slp + if g1 != u1: + # test if the schreier generator is in the i+1-th + # would-be basic stabilizer + y = True + try: + u1_inv = db[gb] + except KeyError: + u1_inv = db[gb] = _af_invert(u1) + schreier_gen = _af_rmul(u1_inv, g1) + u1_inv_slp = slps[i][gb][:] + u1_inv_slp.reverse() + u1_inv_slp = [(i, (g,)) for g in u1_inv_slp] + slp = u1_inv_slp + slp + h, j, slp = _strip_af(schreier_gen, _base, orbs, transversals, i, slp=slp, slps=slps) + if j <= base_len: + # new strong generator h at level j + y = False + elif h: + # h fixes all base points + y = False + moved = 0 + while h[moved] == moved: + moved += 1 + _base.append(moved) + base_len += 1 + strong_gens_distr.append([]) + if y is False: + # if a new strong generator is found, update the + # data structures and start over + h = _af_new(h) + strong_gens_slp.append((h, slp)) + for l in range(i + 1, j): + strong_gens_distr[l].append(h) + transversals[l], slps[l] =\ + _orbit_transversal(degree, strong_gens_distr[l], + _base[l], pairs=True, af=True, slp=True) + transversals[l] = dict(transversals[l]) + orbs[l] = list(transversals[l].keys()) + i = j - 1 + # continue main loop using the flag + continue_i = True + if continue_i is True: + break + if continue_i is True: + break + if continue_i is True: + continue + i -= 1 + + strong_gens = _gens[:] + + if slp_dict: + # create the list of the strong generators strong_gens and + # rewrite the indices of strong_gens_slp in terms of the + # elements of strong_gens + for k, slp in strong_gens_slp: + strong_gens.append(k) + for i in range(len(slp)): + s = slp[i] + if isinstance(s[1], tuple): + slp[i] = strong_gens_distr[s[0]][s[1][0]]**-1 + else: + slp[i] = strong_gens_distr[s[0]][s[1]] + strong_gens_slp = dict(strong_gens_slp) + # add the original generators + for g in _gens: + strong_gens_slp[g] = [g] + return (_base, strong_gens, strong_gens_slp) + + strong_gens.extend([k for k, _ in strong_gens_slp]) + return _base, strong_gens + + def schreier_sims_random(self, base=None, gens=None, consec_succ=10, + _random_prec=None): + r"""Randomized Schreier-Sims algorithm. + + Explanation + =========== + + The randomized Schreier-Sims algorithm takes the sequence ``base`` + and the generating set ``gens``, and extends ``base`` to a base, and + ``gens`` to a strong generating set relative to that base with + probability of a wrong answer at most `2^{-consec\_succ}`, + provided the random generators are sufficiently random. + + Parameters + ========== + + base + The sequence to be extended to a base. + gens + The generating set to be extended to a strong generating set. + consec_succ + The parameter defining the probability of a wrong answer. + _random_prec + An internal parameter used for testing purposes. + + Returns + ======= + + (base, strong_gens) + ``base`` is the base and ``strong_gens`` is the strong generating + set relative to it. + + Examples + ======== + + >>> from sympy.combinatorics.testutil import _verify_bsgs + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> S = SymmetricGroup(5) + >>> base, strong_gens = S.schreier_sims_random(consec_succ=5) + >>> _verify_bsgs(S, base, strong_gens) #doctest: +SKIP + True + + Notes + ===== + + The algorithm is described in detail in [1], pp. 97-98. It extends + the orbits ``orbs`` and the permutation groups ``stabs`` to + basic orbits and basic stabilizers for the base and strong generating + set produced in the end. + The idea of the extension process + is to "sift" random group elements through the stabilizer chain + and amend the stabilizers/orbits along the way when a sift + is not successful. + The helper function ``_strip`` is used to attempt + to decompose a random group element according to the current + state of the stabilizer chain and report whether the element was + fully decomposed (successful sift) or not (unsuccessful sift). In + the latter case, the level at which the sift failed is reported and + used to amend ``stabs``, ``base``, ``gens`` and ``orbs`` accordingly. + The halting condition is for ``consec_succ`` consecutive successful + sifts to pass. This makes sure that the current ``base`` and ``gens`` + form a BSGS with probability at least `1 - 1/\text{consec\_succ}`. + + See Also + ======== + + schreier_sims + + """ + if base is None: + base = [] + if gens is None: + gens = self.generators + base_len = len(base) + n = self.degree + # make sure no generator fixes all base points + for gen in gens: + if all(gen(x) == x for x in base): + new = 0 + while gen._array_form[new] == new: + new += 1 + base.append(new) + base_len += 1 + # distribute generators according to basic stabilizers + strong_gens_distr = _distribute_gens_by_base(base, gens) + # initialize the basic stabilizers, basic transversals and basic orbits + transversals = {} + orbs = {} + for i in range(base_len): + transversals[i] = dict(_orbit_transversal(n, strong_gens_distr[i], + base[i], pairs=True)) + orbs[i] = list(transversals[i].keys()) + # initialize the number of consecutive elements sifted + c = 0 + # start sifting random elements while the number of consecutive sifts + # is less than consec_succ + while c < consec_succ: + if _random_prec is None: + g = self.random_pr() + else: + g = _random_prec['g'].pop() + h, j = _strip(g, base, orbs, transversals) + y = True + # determine whether a new base point is needed + if j <= base_len: + y = False + elif not h.is_Identity: + y = False + moved = 0 + while h(moved) == moved: + moved += 1 + base.append(moved) + base_len += 1 + strong_gens_distr.append([]) + # if the element doesn't sift, amend the strong generators and + # associated stabilizers and orbits + if y is False: + for l in range(1, j): + strong_gens_distr[l].append(h) + transversals[l] = dict(_orbit_transversal(n, + strong_gens_distr[l], base[l], pairs=True)) + orbs[l] = list(transversals[l].keys()) + c = 0 + else: + c += 1 + # build the strong generating set + strong_gens = strong_gens_distr[0][:] + for gen in strong_gens_distr[1]: + if gen not in strong_gens: + strong_gens.append(gen) + return base, strong_gens + + def schreier_vector(self, alpha): + """Computes the schreier vector for ``alpha``. + + Explanation + =========== + + The Schreier vector efficiently stores information + about the orbit of ``alpha``. It can later be used to quickly obtain + elements of the group that send ``alpha`` to a particular element + in the orbit. Notice that the Schreier vector depends on the order + in which the group generators are listed. For a definition, see [3]. + Since list indices start from zero, we adopt the convention to use + "None" instead of 0 to signify that an element does not belong + to the orbit. + For the algorithm and its correctness, see [2], pp.78-80. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation([2, 4, 6, 3, 1, 5, 0]) + >>> b = Permutation([0, 1, 3, 5, 4, 6, 2]) + >>> G = PermutationGroup([a, b]) + >>> G.schreier_vector(0) + [-1, None, 0, 1, None, 1, 0] + + See Also + ======== + + orbit + + """ + n = self.degree + v = [None]*n + v[alpha] = -1 + orb = [alpha] + used = [False]*n + used[alpha] = True + gens = self.generators + r = len(gens) + for b in orb: + for i in range(r): + temp = gens[i]._array_form[b] + if used[temp] is False: + orb.append(temp) + used[temp] = True + v[temp] = i + return v + + def stabilizer(self, alpha): + r"""Return the stabilizer subgroup of ``alpha``. + + Explanation + =========== + + The stabilizer of `\alpha` is the group `G_\alpha = + \{g \in G | g(\alpha) = \alpha\}`. + For a proof of correctness, see [1], p.79. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import DihedralGroup + >>> G = DihedralGroup(6) + >>> G.stabilizer(5) + PermutationGroup([ + (5)(0 4)(1 3)]) + + See Also + ======== + + orbit + + """ + return PermGroup(_stabilizer(self._degree, self._generators, alpha)) + + @property + def strong_gens(self): + r"""Return a strong generating set from the Schreier-Sims algorithm. + + Explanation + =========== + + A generating set `S = \{g_1, g_2, \dots, g_t\}` for a permutation group + `G` is a strong generating set relative to the sequence of points + (referred to as a "base") `(b_1, b_2, \dots, b_k)` if, for + `1 \leq i \leq k` we have that the intersection of the pointwise + stabilizer `G^{(i+1)} := G_{b_1, b_2, \dots, b_i}` with `S` generates + the pointwise stabilizer `G^{(i+1)}`. The concepts of a base and + strong generating set and their applications are discussed in depth + in [1], pp. 87-89 and [2], pp. 55-57. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import DihedralGroup + >>> D = DihedralGroup(4) + >>> D.strong_gens + [(0 1 2 3), (0 3)(1 2), (1 3)] + >>> D.base + [0, 1] + + See Also + ======== + + base, basic_transversals, basic_orbits, basic_stabilizers + + """ + if self._strong_gens == []: + self.schreier_sims() + return self._strong_gens + + def subgroup(self, gens): + """ + Return the subgroup generated by `gens` which is a list of + elements of the group + """ + + if not all(g in self for g in gens): + raise ValueError("The group does not contain the supplied generators") + + G = PermutationGroup(gens) + return G + + def subgroup_search(self, prop, base=None, strong_gens=None, tests=None, + init_subgroup=None): + """Find the subgroup of all elements satisfying the property ``prop``. + + Explanation + =========== + + This is done by a depth-first search with respect to base images that + uses several tests to prune the search tree. + + Parameters + ========== + + prop + The property to be used. Has to be callable on group elements + and always return ``True`` or ``False``. It is assumed that + all group elements satisfying ``prop`` indeed form a subgroup. + base + A base for the supergroup. + strong_gens + A strong generating set for the supergroup. + tests + A list of callables of length equal to the length of ``base``. + These are used to rule out group elements by partial base images, + so that ``tests[l](g)`` returns False if the element ``g`` is known + not to satisfy prop base on where g sends the first ``l + 1`` base + points. + init_subgroup + if a subgroup of the sought group is + known in advance, it can be passed to the function as this + parameter. + + Returns + ======= + + res + The subgroup of all elements satisfying ``prop``. The generating + set for this group is guaranteed to be a strong generating set + relative to the base ``base``. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import (SymmetricGroup, + ... AlternatingGroup) + >>> from sympy.combinatorics.testutil import _verify_bsgs + >>> S = SymmetricGroup(7) + >>> prop_even = lambda x: x.is_even + >>> base, strong_gens = S.schreier_sims_incremental() + >>> G = S.subgroup_search(prop_even, base=base, strong_gens=strong_gens) + >>> G.is_subgroup(AlternatingGroup(7)) + True + >>> _verify_bsgs(G, base, G.generators) + True + + Notes + ===== + + This function is extremely lengthy and complicated and will require + some careful attention. The implementation is described in + [1], pp. 114-117, and the comments for the code here follow the lines + of the pseudocode in the book for clarity. + + The complexity is exponential in general, since the search process by + itself visits all members of the supergroup. However, there are a lot + of tests which are used to prune the search tree, and users can define + their own tests via the ``tests`` parameter, so in practice, and for + some computations, it's not terrible. + + A crucial part in the procedure is the frequent base change performed + (this is line 11 in the pseudocode) in order to obtain a new basic + stabilizer. The book mentiones that this can be done by using + ``.baseswap(...)``, however the current implementation uses a more + straightforward way to find the next basic stabilizer - calling the + function ``.stabilizer(...)`` on the previous basic stabilizer. + + """ + # initialize BSGS and basic group properties + def get_reps(orbits): + # get the minimal element in the base ordering + return [min(orbit, key = lambda x: base_ordering[x]) \ + for orbit in orbits] + + def update_nu(l): + temp_index = len(basic_orbits[l]) + 1 -\ + len(res_basic_orbits_init_base[l]) + # this corresponds to the element larger than all points + if temp_index >= len(sorted_orbits[l]): + nu[l] = base_ordering[degree] + else: + nu[l] = sorted_orbits[l][temp_index] + + if base is None: + base, strong_gens = self.schreier_sims_incremental() + base_len = len(base) + degree = self.degree + identity = _af_new(list(range(degree))) + base_ordering = _base_ordering(base, degree) + # add an element larger than all points + base_ordering.append(degree) + # add an element smaller than all points + base_ordering.append(-1) + # compute BSGS-related structures + strong_gens_distr = _distribute_gens_by_base(base, strong_gens) + basic_orbits, transversals = _orbits_transversals_from_bsgs(base, + strong_gens_distr) + # handle subgroup initialization and tests + if init_subgroup is None: + init_subgroup = PermutationGroup([identity]) + if tests is None: + trivial_test = lambda x: True + tests = [] + for i in range(base_len): + tests.append(trivial_test) + # line 1: more initializations. + res = init_subgroup + f = base_len - 1 + l = base_len - 1 + # line 2: set the base for K to the base for G + res_base = base[:] + # line 3: compute BSGS and related structures for K + res_base, res_strong_gens = res.schreier_sims_incremental( + base=res_base) + res_strong_gens_distr = _distribute_gens_by_base(res_base, + res_strong_gens) + res_generators = res.generators + res_basic_orbits_init_base = \ + [_orbit(degree, res_strong_gens_distr[i], res_base[i])\ + for i in range(base_len)] + # initialize orbit representatives + orbit_reps = [None]*base_len + # line 4: orbit representatives for f-th basic stabilizer of K + orbits = _orbits(degree, res_strong_gens_distr[f]) + orbit_reps[f] = get_reps(orbits) + # line 5: remove the base point from the representatives to avoid + # getting the identity element as a generator for K + orbit_reps[f].remove(base[f]) + # line 6: more initializations + c = [0]*base_len + u = [identity]*base_len + sorted_orbits = [None]*base_len + for i in range(base_len): + sorted_orbits[i] = basic_orbits[i][:] + sorted_orbits[i].sort(key=lambda point: base_ordering[point]) + # line 7: initializations + mu = [None]*base_len + nu = [None]*base_len + # this corresponds to the element smaller than all points + mu[l] = degree + 1 + update_nu(l) + # initialize computed words + computed_words = [identity]*base_len + # line 8: main loop + while True: + # apply all the tests + while l < base_len - 1 and \ + computed_words[l](base[l]) in orbit_reps[l] and \ + base_ordering[mu[l]] < \ + base_ordering[computed_words[l](base[l])] < \ + base_ordering[nu[l]] and \ + tests[l](computed_words): + # line 11: change the (partial) base of K + new_point = computed_words[l](base[l]) + res_base[l] = new_point + new_stab_gens = _stabilizer(degree, res_strong_gens_distr[l], + new_point) + res_strong_gens_distr[l + 1] = new_stab_gens + # line 12: calculate minimal orbit representatives for the + # l+1-th basic stabilizer + orbits = _orbits(degree, new_stab_gens) + orbit_reps[l + 1] = get_reps(orbits) + # line 13: amend sorted orbits + l += 1 + temp_orbit = [computed_words[l - 1](point) for point + in basic_orbits[l]] + temp_orbit.sort(key=lambda point: base_ordering[point]) + sorted_orbits[l] = temp_orbit + # lines 14 and 15: update variables used minimality tests + new_mu = degree + 1 + for i in range(l): + if base[l] in res_basic_orbits_init_base[i]: + candidate = computed_words[i](base[i]) + if base_ordering[candidate] > base_ordering[new_mu]: + new_mu = candidate + mu[l] = new_mu + update_nu(l) + # line 16: determine the new transversal element + c[l] = 0 + temp_point = sorted_orbits[l][c[l]] + gamma = computed_words[l - 1]._array_form.index(temp_point) + u[l] = transversals[l][gamma] + # update computed words + computed_words[l] = rmul(computed_words[l - 1], u[l]) + # lines 17 & 18: apply the tests to the group element found + g = computed_words[l] + temp_point = g(base[l]) + if l == base_len - 1 and \ + base_ordering[mu[l]] < \ + base_ordering[temp_point] < base_ordering[nu[l]] and \ + temp_point in orbit_reps[l] and \ + tests[l](computed_words) and \ + prop(g): + # line 19: reset the base of K + res_generators.append(g) + res_base = base[:] + # line 20: recalculate basic orbits (and transversals) + res_strong_gens.append(g) + res_strong_gens_distr = _distribute_gens_by_base(res_base, + res_strong_gens) + res_basic_orbits_init_base = \ + [_orbit(degree, res_strong_gens_distr[i], res_base[i]) \ + for i in range(base_len)] + # line 21: recalculate orbit representatives + # line 22: reset the search depth + orbit_reps[f] = get_reps(orbits) + l = f + # line 23: go up the tree until in the first branch not fully + # searched + while l >= 0 and c[l] == len(basic_orbits[l]) - 1: + l = l - 1 + # line 24: if the entire tree is traversed, return K + if l == -1: + return PermutationGroup(res_generators) + # lines 25-27: update orbit representatives + if l < f: + # line 26 + f = l + c[l] = 0 + # line 27 + temp_orbits = _orbits(degree, res_strong_gens_distr[f]) + orbit_reps[f] = get_reps(temp_orbits) + # line 28: update variables used for minimality testing + mu[l] = degree + 1 + temp_index = len(basic_orbits[l]) + 1 - \ + len(res_basic_orbits_init_base[l]) + if temp_index >= len(sorted_orbits[l]): + nu[l] = base_ordering[degree] + else: + nu[l] = sorted_orbits[l][temp_index] + # line 29: set the next element from the current branch and update + # accordingly + c[l] += 1 + if l == 0: + gamma = sorted_orbits[l][c[l]] + else: + gamma = computed_words[l - 1]._array_form.index(sorted_orbits[l][c[l]]) + + u[l] = transversals[l][gamma] + if l == 0: + computed_words[l] = u[l] + else: + computed_words[l] = rmul(computed_words[l - 1], u[l]) + + @property + def transitivity_degree(self): + r"""Compute the degree of transitivity of the group. + + Explanation + =========== + + A permutation group `G` acting on `\Omega = \{0, 1, \dots, n-1\}` is + ``k``-fold transitive, if, for any `k` points + `(a_1, a_2, \dots, a_k) \in \Omega` and any `k` points + `(b_1, b_2, \dots, b_k) \in \Omega` there exists `g \in G` such that + `g(a_1) = b_1, g(a_2) = b_2, \dots, g(a_k) = b_k` + The degree of transitivity of `G` is the maximum ``k`` such that + `G` is ``k``-fold transitive. ([8]) + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> a = Permutation([1, 2, 0]) + >>> b = Permutation([1, 0, 2]) + >>> G = PermutationGroup([a, b]) + >>> G.transitivity_degree + 3 + + See Also + ======== + + is_transitive, orbit + + """ + if self._transitivity_degree is None: + n = self.degree + G = self + # if G is k-transitive, a tuple (a_0,..,a_k) + # can be brought to (b_0,...,b_(k-1), b_k) + # where b_0,...,b_(k-1) are fixed points; + # consider the group G_k which stabilizes b_0,...,b_(k-1) + # if G_k is transitive on the subset excluding b_0,...,b_(k-1) + # then G is (k+1)-transitive + for i in range(n): + orb = G.orbit(i) + if len(orb) != n - i: + self._transitivity_degree = i + return i + G = G.stabilizer(i) + self._transitivity_degree = n + return n + else: + return self._transitivity_degree + + def _p_elements_group(self, p): + ''' + For an abelian p-group, return the subgroup consisting of + all elements of order p (and the identity) + + ''' + gens = self.generators[:] + gens = sorted(gens, key=lambda x: x.order(), reverse=True) + gens_p = [g**(g.order()/p) for g in gens] + gens_r = [] + for i in range(len(gens)): + x = gens[i] + x_order = x.order() + # x_p has order p + x_p = x**(x_order/p) + if i > 0: + P = PermutationGroup(gens_p[:i]) + else: + P = PermutationGroup(self.identity) + if x**(x_order/p) not in P: + gens_r.append(x**(x_order/p)) + else: + # replace x by an element of order (x.order()/p) + # so that gens still generates G + g = P.generator_product(x_p, original=True) + for s in g: + x = x*s**-1 + x_order = x_order/p + # insert x to gens so that the sorting is preserved + del gens[i] + del gens_p[i] + j = i - 1 + while j < len(gens) and gens[j].order() >= x_order: + j += 1 + gens = gens[:j] + [x] + gens[j:] + gens_p = gens_p[:j] + [x] + gens_p[j:] + return PermutationGroup(gens_r) + + def _sylow_alt_sym(self, p): + ''' + Return a p-Sylow subgroup of a symmetric or an + alternating group. + + Explanation + =========== + + The algorithm for this is hinted at in [1], Chapter 4, + Exercise 4. + + For Sym(n) with n = p^i, the idea is as follows. Partition + the interval [0..n-1] into p equal parts, each of length p^(i-1): + [0..p^(i-1)-1], [p^(i-1)..2*p^(i-1)-1]...[(p-1)*p^(i-1)..p^i-1]. + Find a p-Sylow subgroup of Sym(p^(i-1)) (treated as a subgroup + of ``self``) acting on each of the parts. Call the subgroups + P_1, P_2...P_p. The generators for the subgroups P_2...P_p + can be obtained from those of P_1 by applying a "shifting" + permutation to them, that is, a permutation mapping [0..p^(i-1)-1] + to the second part (the other parts are obtained by using the shift + multiple times). The union of this permutation and the generators + of P_1 is a p-Sylow subgroup of ``self``. + + For n not equal to a power of p, partition + [0..n-1] in accordance with how n would be written in base p. + E.g. for p=2 and n=11, 11 = 2^3 + 2^2 + 1 so the partition + is [[0..7], [8..9], {10}]. To generate a p-Sylow subgroup, + take the union of the generators for each of the parts. + For the above example, {(0 1), (0 2)(1 3), (0 4), (1 5)(2 7)} + from the first part, {(8 9)} from the second part and + nothing from the third. This gives 4 generators in total, and + the subgroup they generate is p-Sylow. + + Alternating groups are treated the same except when p=2. In this + case, (0 1)(s s+1) should be added for an appropriate s (the start + of a part) for each part in the partitions. + + See Also + ======== + + sylow_subgroup, is_alt_sym + + ''' + n = self.degree + gens = [] + identity = Permutation(n-1) + # the case of 2-sylow subgroups of alternating groups + # needs special treatment + alt = p == 2 and all(g.is_even for g in self.generators) + + # find the presentation of n in base p + coeffs = [] + m = n + while m > 0: + coeffs.append(m % p) + m = m // p + + power = len(coeffs)-1 + # for a symmetric group, gens[:i] is the generating + # set for a p-Sylow subgroup on [0..p**(i-1)-1]. For + # alternating groups, the same is given by gens[:2*(i-1)] + for i in range(1, power+1): + if i == 1 and alt: + # (0 1) shouldn't be added for alternating groups + continue + gen = Permutation([(j + p**(i-1)) % p**i for j in range(p**i)]) + gens.append(identity*gen) + if alt: + gen = Permutation(0, 1)*gen*Permutation(0, 1)*gen + gens.append(gen) + + # the first point in the current part (see the algorithm + # description in the docstring) + start = 0 + + while power > 0: + a = coeffs[power] + + # make the permutation shifting the start of the first + # part ([0..p^i-1] for some i) to the current one + for _ in range(a): + shift = Permutation() + if start > 0: + for i in range(p**power): + shift = shift(i, start + i) + + if alt: + gen = Permutation(0, 1)*shift*Permutation(0, 1)*shift + gens.append(gen) + j = 2*(power - 1) + else: + j = power + + for i, gen in enumerate(gens[:j]): + if alt and i % 2 == 1: + continue + # shift the generator to the start of the + # partition part + gen = shift*gen*shift + gens.append(gen) + + start += p**power + power = power-1 + + return gens + + def sylow_subgroup(self, p): + ''' + Return a p-Sylow subgroup of the group. + + The algorithm is described in [1], Chapter 4, Section 7 + + Examples + ======== + >>> from sympy.combinatorics.named_groups import DihedralGroup + >>> from sympy.combinatorics.named_groups import SymmetricGroup + >>> from sympy.combinatorics.named_groups import AlternatingGroup + + >>> D = DihedralGroup(6) + >>> S = D.sylow_subgroup(2) + >>> S.order() + 4 + >>> G = SymmetricGroup(6) + >>> S = G.sylow_subgroup(5) + >>> S.order() + 5 + + >>> G1 = AlternatingGroup(3) + >>> G2 = AlternatingGroup(5) + >>> G3 = AlternatingGroup(9) + + >>> S1 = G1.sylow_subgroup(3) + >>> S2 = G2.sylow_subgroup(3) + >>> S3 = G3.sylow_subgroup(3) + + >>> len1 = len(S1.lower_central_series()) + >>> len2 = len(S2.lower_central_series()) + >>> len3 = len(S3.lower_central_series()) + + >>> len1 == len2 + True + >>> len1 < len3 + True + + ''' + from sympy.combinatorics.homomorphisms import ( + orbit_homomorphism, block_homomorphism) + + if not isprime(p): + raise ValueError("p must be a prime") + + def is_p_group(G): + # check if the order of G is a power of p + # and return the power + m = G.order() + n = 0 + while m % p == 0: + m = m/p + n += 1 + if m == 1: + return True, n + return False, n + + def _sylow_reduce(mu, nu): + # reduction based on two homomorphisms + # mu and nu with trivially intersecting + # kernels + Q = mu.image().sylow_subgroup(p) + Q = mu.invert_subgroup(Q) + nu = nu.restrict_to(Q) + R = nu.image().sylow_subgroup(p) + return nu.invert_subgroup(R) + + order = self.order() + if order % p != 0: + return PermutationGroup([self.identity]) + p_group, n = is_p_group(self) + if p_group: + return self + + if self.is_alt_sym(): + return PermutationGroup(self._sylow_alt_sym(p)) + + # if there is a non-trivial orbit with size not divisible + # by p, the sylow subgroup is contained in its stabilizer + # (by orbit-stabilizer theorem) + orbits = self.orbits() + non_p_orbits = [o for o in orbits if len(o) % p != 0 and len(o) != 1] + if non_p_orbits: + G = self.stabilizer(list(non_p_orbits[0]).pop()) + return G.sylow_subgroup(p) + + if not self.is_transitive(): + # apply _sylow_reduce to orbit actions + orbits = sorted(orbits, key=len) + omega1 = orbits.pop() + omega2 = orbits[0].union(*orbits) + mu = orbit_homomorphism(self, omega1) + nu = orbit_homomorphism(self, omega2) + return _sylow_reduce(mu, nu) + + blocks = self.minimal_blocks() + if len(blocks) > 1: + # apply _sylow_reduce to block system actions + mu = block_homomorphism(self, blocks[0]) + nu = block_homomorphism(self, blocks[1]) + return _sylow_reduce(mu, nu) + elif len(blocks) == 1: + block = list(blocks)[0] + if any(e != 0 for e in block): + # self is imprimitive + mu = block_homomorphism(self, block) + if not is_p_group(mu.image())[0]: + S = mu.image().sylow_subgroup(p) + return mu.invert_subgroup(S).sylow_subgroup(p) + + # find an element of order p + g = self.random() + g_order = g.order() + while g_order % p != 0 or g_order == 0: + g = self.random() + g_order = g.order() + g = g**(g_order // p) + if order % p**2 != 0: + return PermutationGroup(g) + + C = self.centralizer(g) + while C.order() % p**n != 0: + S = C.sylow_subgroup(p) + s_order = S.order() + Z = S.center() + P = Z._p_elements_group(p) + h = P.random() + C_h = self.centralizer(h) + while C_h.order() % p*s_order != 0: + h = P.random() + C_h = self.centralizer(h) + C = C_h + + return C.sylow_subgroup(p) + + def _block_verify(self, L, alpha): + delta = sorted(self.orbit(alpha)) + # p[i] will be the number of the block + # delta[i] belongs to + p = [-1]*len(delta) + blocks = [-1]*len(delta) + + B = [[]] # future list of blocks + u = [0]*len(delta) # u[i] in L s.t. alpha^u[i] = B[0][i] + + t = L.orbit_transversal(alpha, pairs=True) + for a, beta in t: + B[0].append(a) + i_a = delta.index(a) + p[i_a] = 0 + blocks[i_a] = alpha + u[i_a] = beta + + rho = 0 + m = 0 # number of blocks - 1 + + while rho <= m: + beta = B[rho][0] + for g in self.generators: + d = beta^g + i_d = delta.index(d) + sigma = p[i_d] + if sigma < 0: + # define a new block + m += 1 + sigma = m + u[i_d] = u[delta.index(beta)]*g + p[i_d] = sigma + rep = d + blocks[i_d] = rep + newb = [rep] + for gamma in B[rho][1:]: + i_gamma = delta.index(gamma) + d = gamma^g + i_d = delta.index(d) + if p[i_d] < 0: + u[i_d] = u[i_gamma]*g + p[i_d] = sigma + blocks[i_d] = rep + newb.append(d) + else: + # B[rho] is not a block + s = u[i_gamma]*g*u[i_d]**(-1) + return False, s + + B.append(newb) + else: + for h in B[rho][1:]: + if h^g not in B[sigma]: + # B[rho] is not a block + s = u[delta.index(beta)]*g*u[i_d]**(-1) + return False, s + rho += 1 + + return True, blocks + + def _verify(H, K, phi, z, alpha): + ''' + Return a list of relators ``rels`` in generators ``gens`_h` that + are mapped to ``H.generators`` by ``phi`` so that given a finite + presentation of ``K`` on a subset of ``gens_h`` + is a finite presentation of ``H``. + + Explanation + =========== + + ``H`` should be generated by the union of ``K.generators`` and ``z`` + (a single generator), and ``H.stabilizer(alpha) == K``; ``phi`` is a + canonical injection from a free group into a permutation group + containing ``H``. + + The algorithm is described in [1], Chapter 6. + + Examples + ======== + + >>> from sympy.combinatorics import free_group, Permutation, PermutationGroup + >>> from sympy.combinatorics.homomorphisms import homomorphism + >>> from sympy.combinatorics.fp_groups import FpGroup + + >>> H = PermutationGroup(Permutation(0, 2), Permutation (1, 5)) + >>> K = PermutationGroup(Permutation(5)(0, 2)) + >>> F = free_group("x_0 x_1")[0] + >>> gens = F.generators + >>> phi = homomorphism(F, H, F.generators, H.generators) + >>> rels_k = [gens[0]**2] # relators for presentation of K + >>> z= Permutation(1, 5) + >>> check, rels_h = H._verify(K, phi, z, 1) + >>> check + True + >>> rels = rels_k + rels_h + >>> G = FpGroup(F, rels) # presentation of H + >>> G.order() == H.order() + True + + See also + ======== + + strong_presentation, presentation, stabilizer + + ''' + + orbit = H.orbit(alpha) + beta = alpha^(z**-1) + + K_beta = K.stabilizer(beta) + + # orbit representatives of K_beta + gammas = [alpha, beta] + orbits = list({tuple(K_beta.orbit(o)) for o in orbit}) + orbit_reps = [orb[0] for orb in orbits] + for rep in orbit_reps: + if rep not in gammas: + gammas.append(rep) + + # orbit transversal of K + betas = [alpha, beta] + transversal = {alpha: phi.invert(H.identity), beta: phi.invert(z**-1)} + + for s, g in K.orbit_transversal(beta, pairs=True): + if s not in transversal: + transversal[s] = transversal[beta]*phi.invert(g) + + + union = K.orbit(alpha).union(K.orbit(beta)) + while (len(union) < len(orbit)): + for gamma in gammas: + if gamma in union: + r = gamma^z + if r not in union: + betas.append(r) + transversal[r] = transversal[gamma]*phi.invert(z) + for s, g in K.orbit_transversal(r, pairs=True): + if s not in transversal: + transversal[s] = transversal[r]*phi.invert(g) + union = union.union(K.orbit(r)) + break + + # compute relators + rels = [] + + for b in betas: + k_gens = K.stabilizer(b).generators + for y in k_gens: + new_rel = transversal[b] + gens = K.generator_product(y, original=True) + for g in gens[::-1]: + new_rel = new_rel*phi.invert(g) + new_rel = new_rel*transversal[b]**-1 + + perm = phi(new_rel) + try: + gens = K.generator_product(perm, original=True) + except ValueError: + return False, perm + for g in gens: + new_rel = new_rel*phi.invert(g)**-1 + if new_rel not in rels: + rels.append(new_rel) + + for gamma in gammas: + new_rel = transversal[gamma]*phi.invert(z)*transversal[gamma^z]**-1 + perm = phi(new_rel) + try: + gens = K.generator_product(perm, original=True) + except ValueError: + return False, perm + for g in gens: + new_rel = new_rel*phi.invert(g)**-1 + if new_rel not in rels: + rels.append(new_rel) + + return True, rels + + def strong_presentation(self): + ''' + Return a strong finite presentation of group. The generators + of the returned group are in the same order as the strong + generators of group. + + The algorithm is based on Sims' Verify algorithm described + in [1], Chapter 6. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import DihedralGroup + >>> P = DihedralGroup(4) + >>> G = P.strong_presentation() + >>> P.order() == G.order() + True + + See Also + ======== + + presentation, _verify + + ''' + from sympy.combinatorics.fp_groups import (FpGroup, + simplify_presentation) + from sympy.combinatorics.free_groups import free_group + from sympy.combinatorics.homomorphisms import (block_homomorphism, + homomorphism, GroupHomomorphism) + + strong_gens = self.strong_gens[:] + stabs = self.basic_stabilizers[:] + base = self.base[:] + + # injection from a free group on len(strong_gens) + # generators into G + gen_syms = [('x_%d'%i) for i in range(len(strong_gens))] + F = free_group(', '.join(gen_syms))[0] + phi = homomorphism(F, self, F.generators, strong_gens) + + H = PermutationGroup(self.identity) + while stabs: + alpha = base.pop() + K = H + H = stabs.pop() + new_gens = [g for g in H.generators if g not in K] + + if K.order() == 1: + z = new_gens.pop() + rels = [F.generators[-1]**z.order()] + intermediate_gens = [z] + K = PermutationGroup(intermediate_gens) + + # add generators one at a time building up from K to H + while new_gens: + z = new_gens.pop() + intermediate_gens = [z] + intermediate_gens + K_s = PermutationGroup(intermediate_gens) + orbit = K_s.orbit(alpha) + orbit_k = K.orbit(alpha) + + # split into cases based on the orbit of K_s + if orbit_k == orbit: + if z in K: + rel = phi.invert(z) + perm = z + else: + t = K.orbit_rep(alpha, alpha^z) + rel = phi.invert(z)*phi.invert(t)**-1 + perm = z*t**-1 + for g in K.generator_product(perm, original=True): + rel = rel*phi.invert(g)**-1 + new_rels = [rel] + elif len(orbit_k) == 1: + # `success` is always true because `strong_gens` + # and `base` are already a verified BSGS. Later + # this could be changed to start with a randomly + # generated (potential) BSGS, and then new elements + # would have to be appended to it when `success` + # is false. + success, new_rels = K_s._verify(K, phi, z, alpha) + else: + # K.orbit(alpha) should be a block + # under the action of K_s on K_s.orbit(alpha) + check, block = K_s._block_verify(K, alpha) + if check: + # apply _verify to the action of K_s + # on the block system; for convenience, + # add the blocks as additional points + # that K_s should act on + t = block_homomorphism(K_s, block) + m = t.codomain.degree # number of blocks + d = K_s.degree + + # conjugating with p will shift + # permutations in t.image() to + # higher numbers, e.g. + # p*(0 1)*p = (m m+1) + p = Permutation() + for i in range(m): + p *= Permutation(i, i+d) + + t_img = t.images + # combine generators of K_s with their + # action on the block system + images = {g: g*p*t_img[g]*p for g in t_img} + for g in self.strong_gens[:-len(K_s.generators)]: + images[g] = g + K_s_act = PermutationGroup(list(images.values())) + f = GroupHomomorphism(self, K_s_act, images) + + K_act = PermutationGroup([f(g) for g in K.generators]) + success, new_rels = K_s_act._verify(K_act, f.compose(phi), f(z), d) + + for n in new_rels: + if n not in rels: + rels.append(n) + K = K_s + + group = FpGroup(F, rels) + return simplify_presentation(group) + + def presentation(self, eliminate_gens=True): + ''' + Return an `FpGroup` presentation of the group. + + The algorithm is described in [1], Chapter 6.1. + + ''' + from sympy.combinatorics.fp_groups import (FpGroup, + simplify_presentation) + from sympy.combinatorics.coset_table import CosetTable + from sympy.combinatorics.free_groups import free_group + from sympy.combinatorics.homomorphisms import homomorphism + + if self._fp_presentation: + return self._fp_presentation + + def _factor_group_by_rels(G, rels): + if isinstance(G, FpGroup): + rels.extend(G.relators) + return FpGroup(G.free_group, list(set(rels))) + return FpGroup(G, rels) + + gens = self.generators + len_g = len(gens) + + if len_g == 1: + order = gens[0].order() + # handle the trivial group + if order == 1: + return free_group([])[0] + F, x = free_group('x') + return FpGroup(F, [x**order]) + + if self.order() > 20: + half_gens = self.generators[0:(len_g+1)//2] + else: + half_gens = [] + H = PermutationGroup(half_gens) + H_p = H.presentation() + + len_h = len(H_p.generators) + + C = self.coset_table(H) + n = len(C) # subgroup index + + gen_syms = [('x_%d'%i) for i in range(len(gens))] + F = free_group(', '.join(gen_syms))[0] + + # mapping generators of H_p to those of F + images = [F.generators[i] for i in range(len_h)] + R = homomorphism(H_p, F, H_p.generators, images, check=False) + + # rewrite relators + rels = R(H_p.relators) + G_p = FpGroup(F, rels) + + # injective homomorphism from G_p into self + T = homomorphism(G_p, self, G_p.generators, gens) + + C_p = CosetTable(G_p, []) + + C_p.table = [[None]*(2*len_g) for i in range(n)] + + # initiate the coset transversal + transversal = [None]*n + transversal[0] = G_p.identity + + # fill in the coset table as much as possible + for i in range(2*len_h): + C_p.table[0][i] = 0 + + gamma = 1 + for alpha, x in product(range(n), range(2*len_g)): + beta = C[alpha][x] + if beta == gamma: + gen = G_p.generators[x//2]**((-1)**(x % 2)) + transversal[beta] = transversal[alpha]*gen + C_p.table[alpha][x] = beta + C_p.table[beta][x + (-1)**(x % 2)] = alpha + gamma += 1 + if gamma == n: + break + + C_p.p = list(range(n)) + beta = x = 0 + + while not C_p.is_complete(): + # find the first undefined entry + while C_p.table[beta][x] == C[beta][x]: + x = (x + 1) % (2*len_g) + if x == 0: + beta = (beta + 1) % n + + # define a new relator + gen = G_p.generators[x//2]**((-1)**(x % 2)) + new_rel = transversal[beta]*gen*transversal[C[beta][x]]**-1 + perm = T(new_rel) + nxt = G_p.identity + for s in H.generator_product(perm, original=True): + nxt = nxt*T.invert(s)**-1 + new_rel = new_rel*nxt + + # continue coset enumeration + G_p = _factor_group_by_rels(G_p, [new_rel]) + C_p.scan_and_fill(0, new_rel) + C_p = G_p.coset_enumeration([], strategy="coset_table", + draft=C_p, max_cosets=n, incomplete=True) + + self._fp_presentation = simplify_presentation(G_p) + return self._fp_presentation + + def polycyclic_group(self): + """ + Return the PolycyclicGroup instance with below parameters: + + Explanation + =========== + + * pc_sequence : Polycyclic sequence is formed by collecting all + the missing generators between the adjacent groups in the + derived series of given permutation group. + + * pc_series : Polycyclic series is formed by adding all the missing + generators of ``der[i+1]`` in ``der[i]``, where ``der`` represents + the derived series. + + * relative_order : A list, computed by the ratio of adjacent groups in + pc_series. + + """ + from sympy.combinatorics.pc_groups import PolycyclicGroup + if not self.is_polycyclic: + raise ValueError("The group must be solvable") + + der = self.derived_series() + pc_series = [] + pc_sequence = [] + relative_order = [] + pc_series.append(der[-1]) + der.reverse() + + for i in range(len(der)-1): + H = der[i] + for g in der[i+1].generators: + if g not in H: + H = PermutationGroup([g] + H.generators) + pc_series.insert(0, H) + pc_sequence.insert(0, g) + + G1 = pc_series[0].order() + G2 = pc_series[1].order() + relative_order.insert(0, G1 // G2) + + return PolycyclicGroup(pc_sequence, pc_series, relative_order, collector=None) + + +def _orbit(degree, generators, alpha, action='tuples'): + r"""Compute the orbit of alpha `\{g(\alpha) | g \in G\}` as a set. + + Explanation + =========== + + The time complexity of the algorithm used here is `O(|Orb|*r)` where + `|Orb|` is the size of the orbit and ``r`` is the number of generators of + the group. For a more detailed analysis, see [1], p.78, [2], pp. 19-21. + Here alpha can be a single point, or a list of points. + + If alpha is a single point, the ordinary orbit is computed. + if alpha is a list of points, there are three available options: + + 'union' - computes the union of the orbits of the points in the list + 'tuples' - computes the orbit of the list interpreted as an ordered + tuple under the group action ( i.e., g((1, 2, 3)) = (g(1), g(2), g(3)) ) + 'sets' - computes the orbit of the list interpreted as a sets + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> from sympy.combinatorics.perm_groups import _orbit + >>> a = Permutation([1, 2, 0, 4, 5, 6, 3]) + >>> G = PermutationGroup([a]) + >>> _orbit(G.degree, G.generators, 0) + {0, 1, 2} + >>> _orbit(G.degree, G.generators, [0, 4], 'union') + {0, 1, 2, 3, 4, 5, 6} + + See Also + ======== + + orbit, orbit_transversal + + """ + if not hasattr(alpha, '__getitem__'): + alpha = [alpha] + + gens = [x._array_form for x in generators] + if len(alpha) == 1 or action == 'union': + orb = alpha + used = [False]*degree + for el in alpha: + used[el] = True + for b in orb: + for gen in gens: + temp = gen[b] + if used[temp] == False: + orb.append(temp) + used[temp] = True + return set(orb) + elif action == 'tuples': + alpha = tuple(alpha) + orb = [alpha] + used = {alpha} + for b in orb: + for gen in gens: + temp = tuple([gen[x] for x in b]) + if temp not in used: + orb.append(temp) + used.add(temp) + return set(orb) + elif action == 'sets': + alpha = frozenset(alpha) + orb = [alpha] + used = {alpha} + for b in orb: + for gen in gens: + temp = frozenset([gen[x] for x in b]) + if temp not in used: + orb.append(temp) + used.add(temp) + return {tuple(x) for x in orb} + + +def _orbits(degree, generators): + """Compute the orbits of G. + + If ``rep=False`` it returns a list of sets else it returns a list of + representatives of the orbits + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> from sympy.combinatorics.perm_groups import _orbits + >>> a = Permutation([0, 2, 1]) + >>> b = Permutation([1, 0, 2]) + >>> _orbits(a.size, [a, b]) + [{0, 1, 2}] + """ + + orbs = [] + sorted_I = list(range(degree)) + I = set(sorted_I) + while I: + i = sorted_I[0] + orb = _orbit(degree, generators, i) + orbs.append(orb) + # remove all indices that are in this orbit + I -= orb + sorted_I = [i for i in sorted_I if i not in orb] + return orbs + + +def _orbit_transversal(degree, generators, alpha, pairs, af=False, slp=False): + r"""Computes a transversal for the orbit of ``alpha`` as a set. + + Explanation + =========== + + generators generators of the group ``G`` + + For a permutation group ``G``, a transversal for the orbit + `Orb = \{g(\alpha) | g \in G\}` is a set + `\{g_\beta | g_\beta(\alpha) = \beta\}` for `\beta \in Orb`. + Note that there may be more than one possible transversal. + If ``pairs`` is set to ``True``, it returns the list of pairs + `(\beta, g_\beta)`. For a proof of correctness, see [1], p.79 + + if ``af`` is ``True``, the transversal elements are given in + array form. + + If `slp` is `True`, a dictionary `{beta: slp_beta}` is returned + for `\beta \in Orb` where `slp_beta` is a list of indices of the + generators in `generators` s.t. if `slp_beta = [i_1 \dots i_n]` + `g_\beta = generators[i_n] \times \dots \times generators[i_1]`. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import DihedralGroup + >>> from sympy.combinatorics.perm_groups import _orbit_transversal + >>> G = DihedralGroup(6) + >>> _orbit_transversal(G.degree, G.generators, 0, False) + [(5), (0 1 2 3 4 5), (0 5)(1 4)(2 3), (0 2 4)(1 3 5), (5)(0 4)(1 3), (0 3)(1 4)(2 5)] + """ + + tr = [(alpha, list(range(degree)))] + slp_dict = {alpha: []} + used = [False]*degree + used[alpha] = True + gens = [x._array_form for x in generators] + for x, px in tr: + px_slp = slp_dict[x] + for gen in gens: + temp = gen[x] + if used[temp] == False: + slp_dict[temp] = [gens.index(gen)] + px_slp + tr.append((temp, _af_rmul(gen, px))) + used[temp] = True + if pairs: + if not af: + tr = [(x, _af_new(y)) for x, y in tr] + if not slp: + return tr + return tr, slp_dict + + if af: + tr = [y for _, y in tr] + if not slp: + return tr + return tr, slp_dict + + tr = [_af_new(y) for _, y in tr] + if not slp: + return tr + return tr, slp_dict + + +def _stabilizer(degree, generators, alpha): + r"""Return the stabilizer subgroup of ``alpha``. + + Explanation + =========== + + The stabilizer of `\alpha` is the group `G_\alpha = + \{g \in G | g(\alpha) = \alpha\}`. + For a proof of correctness, see [1], p.79. + + degree : degree of G + generators : generators of G + + Examples + ======== + + >>> from sympy.combinatorics.perm_groups import _stabilizer + >>> from sympy.combinatorics.named_groups import DihedralGroup + >>> G = DihedralGroup(6) + >>> _stabilizer(G.degree, G.generators, 5) + [(5)(0 4)(1 3), (5)] + + See Also + ======== + + orbit + + """ + orb = [alpha] + table = {alpha: list(range(degree))} + table_inv = {alpha: list(range(degree))} + used = [False]*degree + used[alpha] = True + gens = [x._array_form for x in generators] + stab_gens = [] + for b in orb: + for gen in gens: + temp = gen[b] + if used[temp] is False: + gen_temp = _af_rmul(gen, table[b]) + orb.append(temp) + table[temp] = gen_temp + table_inv[temp] = _af_invert(gen_temp) + used[temp] = True + else: + schreier_gen = _af_rmuln(table_inv[temp], gen, table[b]) + if schreier_gen not in stab_gens: + stab_gens.append(schreier_gen) + return [_af_new(x) for x in stab_gens] + + +PermGroup = PermutationGroup + + +class SymmetricPermutationGroup(Basic): + """ + The class defining the lazy form of SymmetricGroup. + + deg : int + + """ + def __new__(cls, deg): + deg = _sympify(deg) + obj = Basic.__new__(cls, deg) + return obj + + def __init__(self, *args, **kwargs): + self._deg = self.args[0] + self._order = None + + def __contains__(self, i): + """Return ``True`` if *i* is contained in SymmetricPermutationGroup. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, SymmetricPermutationGroup + >>> G = SymmetricPermutationGroup(4) + >>> Permutation(1, 2, 3) in G + True + + """ + if not isinstance(i, Permutation): + raise TypeError("A SymmetricPermutationGroup contains only Permutations as " + "elements, not elements of type %s" % type(i)) + return i.size == self.degree + + def order(self): + """ + Return the order of the SymmetricPermutationGroup. + + Examples + ======== + + >>> from sympy.combinatorics import SymmetricPermutationGroup + >>> G = SymmetricPermutationGroup(4) + >>> G.order() + 24 + """ + if self._order is not None: + return self._order + n = self._deg + self._order = factorial(n) + return self._order + + @property + def degree(self): + """ + Return the degree of the SymmetricPermutationGroup. + + Examples + ======== + + >>> from sympy.combinatorics import SymmetricPermutationGroup + >>> G = SymmetricPermutationGroup(4) + >>> G.degree + 4 + + """ + return self._deg + + @property + def identity(self): + ''' + Return the identity element of the SymmetricPermutationGroup. + + Examples + ======== + + >>> from sympy.combinatorics import SymmetricPermutationGroup + >>> G = SymmetricPermutationGroup(4) + >>> G.identity() + (3) + + ''' + return _af_new(list(range(self._deg))) + + +class Coset(Basic): + """A left coset of a permutation group with respect to an element. + + Parameters + ========== + + g : Permutation + + H : PermutationGroup + + dir : "+" or "-", If not specified by default it will be "+" + here ``dir`` specified the type of coset "+" represent the + right coset and "-" represent the left coset. + + G : PermutationGroup, optional + The group which contains *H* as its subgroup and *g* as its + element. + + If not specified, it would automatically become a symmetric + group ``SymmetricPermutationGroup(g.size)`` and + ``SymmetricPermutationGroup(H.degree)`` if ``g.size`` and ``H.degree`` + are matching.``SymmetricPermutationGroup`` is a lazy form of SymmetricGroup + used for representation purpose. + + """ + + def __new__(cls, g, H, G=None, dir="+"): + g = _sympify(g) + if not isinstance(g, Permutation): + raise NotImplementedError + + H = _sympify(H) + if not isinstance(H, PermutationGroup): + raise NotImplementedError + + if G is not None: + G = _sympify(G) + if not isinstance(G, (PermutationGroup, SymmetricPermutationGroup)): + raise NotImplementedError + if not H.is_subgroup(G): + raise ValueError("{} must be a subgroup of {}.".format(H, G)) + if g not in G: + raise ValueError("{} must be an element of {}.".format(g, G)) + else: + g_size = g.size + h_degree = H.degree + if g_size != h_degree: + raise ValueError( + "The size of the permutation {} and the degree of " + "the permutation group {} should be matching " + .format(g, H)) + G = SymmetricPermutationGroup(g.size) + + if isinstance(dir, str): + dir = Symbol(dir) + elif not isinstance(dir, Symbol): + raise TypeError("dir must be of type basestring or " + "Symbol, not %s" % type(dir)) + if str(dir) not in ('+', '-'): + raise ValueError("dir must be one of '+' or '-' not %s" % dir) + obj = Basic.__new__(cls, g, H, G, dir) + return obj + + def __init__(self, *args, **kwargs): + self._dir = self.args[3] + + @property + def is_left_coset(self): + """ + Check if the coset is left coset that is ``gH``. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup, Coset + >>> a = Permutation(1, 2) + >>> b = Permutation(0, 1) + >>> G = PermutationGroup([a, b]) + >>> cst = Coset(a, G, dir="-") + >>> cst.is_left_coset + True + + """ + return str(self._dir) == '-' + + @property + def is_right_coset(self): + """ + Check if the coset is right coset that is ``Hg``. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, PermutationGroup, Coset + >>> a = Permutation(1, 2) + >>> b = Permutation(0, 1) + >>> G = PermutationGroup([a, b]) + >>> cst = Coset(a, G, dir="+") + >>> cst.is_right_coset + True + + """ + return str(self._dir) == '+' + + def as_list(self): + """ + Return all the elements of coset in the form of list. + """ + g = self.args[0] + H = self.args[1] + cst = [] + if str(self._dir) == '+': + for h in H.elements: + cst.append(h*g) + else: + for h in H.elements: + cst.append(g*h) + return cst diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/permutations.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/permutations.py new file mode 100644 index 0000000000000000000000000000000000000000..3718467b69cbab2b9b0dd73e8fa160cceb0324bf --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/permutations.py @@ -0,0 +1,3114 @@ +import random +from collections import defaultdict +from collections.abc import Iterable +from functools import reduce + +from sympy.core.parameters import global_parameters +from sympy.core.basic import Atom +from sympy.core.expr import Expr +from sympy.core.numbers import int_valued +from sympy.core.numbers import Integer +from sympy.core.sympify import _sympify +from sympy.matrices import zeros +from sympy.polys.polytools import lcm +from sympy.printing.repr import srepr +from sympy.utilities.iterables import (flatten, has_variety, minlex, + has_dups, runs, is_sequence) +from sympy.utilities.misc import as_int +from mpmath.libmp.libintmath import ifac +from sympy.multipledispatch import dispatch + +def _af_rmul(a, b): + """ + Return the product b*a; input and output are array forms. The ith value + is a[b[i]]. + + Examples + ======== + + >>> from sympy.combinatorics.permutations import _af_rmul, Permutation + + >>> a, b = [1, 0, 2], [0, 2, 1] + >>> _af_rmul(a, b) + [1, 2, 0] + >>> [a[b[i]] for i in range(3)] + [1, 2, 0] + + This handles the operands in reverse order compared to the ``*`` operator: + + >>> a = Permutation(a) + >>> b = Permutation(b) + >>> list(a*b) + [2, 0, 1] + >>> [b(a(i)) for i in range(3)] + [2, 0, 1] + + See Also + ======== + + rmul, _af_rmuln + """ + return [a[i] for i in b] + + +def _af_rmuln(*abc): + """ + Given [a, b, c, ...] return the product of ...*c*b*a using array forms. + The ith value is a[b[c[i]]]. + + Examples + ======== + + >>> from sympy.combinatorics.permutations import _af_rmul, Permutation + + >>> a, b = [1, 0, 2], [0, 2, 1] + >>> _af_rmul(a, b) + [1, 2, 0] + >>> [a[b[i]] for i in range(3)] + [1, 2, 0] + + This handles the operands in reverse order compared to the ``*`` operator: + + >>> a = Permutation(a); b = Permutation(b) + >>> list(a*b) + [2, 0, 1] + >>> [b(a(i)) for i in range(3)] + [2, 0, 1] + + See Also + ======== + + rmul, _af_rmul + """ + a = abc + m = len(a) + if m == 3: + p0, p1, p2 = a + return [p0[p1[i]] for i in p2] + if m == 4: + p0, p1, p2, p3 = a + return [p0[p1[p2[i]]] for i in p3] + if m == 5: + p0, p1, p2, p3, p4 = a + return [p0[p1[p2[p3[i]]]] for i in p4] + if m == 6: + p0, p1, p2, p3, p4, p5 = a + return [p0[p1[p2[p3[p4[i]]]]] for i in p5] + if m == 7: + p0, p1, p2, p3, p4, p5, p6 = a + return [p0[p1[p2[p3[p4[p5[i]]]]]] for i in p6] + if m == 8: + p0, p1, p2, p3, p4, p5, p6, p7 = a + return [p0[p1[p2[p3[p4[p5[p6[i]]]]]]] for i in p7] + if m == 1: + return a[0][:] + if m == 2: + a, b = a + return [a[i] for i in b] + if m == 0: + raise ValueError("String must not be empty") + p0 = _af_rmuln(*a[:m//2]) + p1 = _af_rmuln(*a[m//2:]) + return [p0[i] for i in p1] + + +def _af_parity(pi): + """ + Computes the parity of a permutation in array form. + + Explanation + =========== + + The parity of a permutation reflects the parity of the + number of inversions in the permutation, i.e., the + number of pairs of x and y such that x > y but p[x] < p[y]. + + Examples + ======== + + >>> from sympy.combinatorics.permutations import _af_parity + >>> _af_parity([0, 1, 2, 3]) + 0 + >>> _af_parity([3, 2, 0, 1]) + 1 + + See Also + ======== + + Permutation + """ + n = len(pi) + a = [0] * n + c = 0 + for j in range(n): + if a[j] == 0: + c += 1 + a[j] = 1 + i = j + while pi[i] != j: + i = pi[i] + a[i] = 1 + return (n - c) % 2 + + +def _af_invert(a): + """ + Finds the inverse, ~A, of a permutation, A, given in array form. + + Examples + ======== + + >>> from sympy.combinatorics.permutations import _af_invert, _af_rmul + >>> A = [1, 2, 0, 3] + >>> _af_invert(A) + [2, 0, 1, 3] + >>> _af_rmul(_, A) + [0, 1, 2, 3] + + See Also + ======== + + Permutation, __invert__ + """ + inv_form = [0] * len(a) + for i, ai in enumerate(a): + inv_form[ai] = i + return inv_form + + +def _af_pow(a, n): + """ + Routine for finding powers of a permutation. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> from sympy.combinatorics.permutations import _af_pow + >>> p = Permutation([2, 0, 3, 1]) + >>> p.order() + 4 + >>> _af_pow(p._array_form, 4) + [0, 1, 2, 3] + """ + if n == 0: + return list(range(len(a))) + if n < 0: + return _af_pow(_af_invert(a), -n) + if n == 1: + return a[:] + elif n == 2: + b = [a[i] for i in a] + elif n == 3: + b = [a[a[i]] for i in a] + elif n == 4: + b = [a[a[a[i]]] for i in a] + else: + # use binary multiplication + b = list(range(len(a))) + while 1: + if n & 1: + b = [b[i] for i in a] + n -= 1 + if not n: + break + if n % 4 == 0: + a = [a[a[a[i]]] for i in a] + n = n // 4 + elif n % 2 == 0: + a = [a[i] for i in a] + n = n // 2 + return b + + +def _af_commutes_with(a, b): + """ + Checks if the two permutations with array forms + given by ``a`` and ``b`` commute. + + Examples + ======== + + >>> from sympy.combinatorics.permutations import _af_commutes_with + >>> _af_commutes_with([1, 2, 0], [0, 2, 1]) + False + + See Also + ======== + + Permutation, commutes_with + """ + return not any(a[b[i]] != b[a[i]] for i in range(len(a) - 1)) + + +class Cycle(dict): + """ + Wrapper around dict which provides the functionality of a disjoint cycle. + + Explanation + =========== + + A cycle shows the rule to use to move subsets of elements to obtain + a permutation. The Cycle class is more flexible than Permutation in + that 1) all elements need not be present in order to investigate how + multiple cycles act in sequence and 2) it can contain singletons: + + >>> from sympy.combinatorics.permutations import Perm, Cycle + + A Cycle will automatically parse a cycle given as a tuple on the rhs: + + >>> Cycle(1, 2)(2, 3) + (1 3 2) + + The identity cycle, Cycle(), can be used to start a product: + + >>> Cycle()(1, 2)(2, 3) + (1 3 2) + + The array form of a Cycle can be obtained by calling the list + method (or passing it to the list function) and all elements from + 0 will be shown: + + >>> a = Cycle(1, 2) + >>> a.list() + [0, 2, 1] + >>> list(a) + [0, 2, 1] + + If a larger (or smaller) range is desired use the list method and + provide the desired size -- but the Cycle cannot be truncated to + a size smaller than the largest element that is out of place: + + >>> b = Cycle(2, 4)(1, 2)(3, 1, 4)(1, 3) + >>> b.list() + [0, 2, 1, 3, 4] + >>> b.list(b.size + 1) + [0, 2, 1, 3, 4, 5] + >>> b.list(-1) + [0, 2, 1] + + Singletons are not shown when printing with one exception: the largest + element is always shown -- as a singleton if necessary: + + >>> Cycle(1, 4, 10)(4, 5) + (1 5 4 10) + >>> Cycle(1, 2)(4)(5)(10) + (1 2)(10) + + The array form can be used to instantiate a Permutation so other + properties of the permutation can be investigated: + + >>> Perm(Cycle(1, 2)(3, 4).list()).transpositions() + [(1, 2), (3, 4)] + + Notes + ===== + + The underlying structure of the Cycle is a dictionary and although + the __iter__ method has been redefined to give the array form of the + cycle, the underlying dictionary items are still available with the + such methods as items(): + + >>> list(Cycle(1, 2).items()) + [(1, 2), (2, 1)] + + See Also + ======== + + Permutation + """ + def __missing__(self, arg): + """Enter arg into dictionary and return arg.""" + return as_int(arg) + + def __iter__(self): + yield from self.list() + + def __call__(self, *other): + """Return product of cycles processed from R to L. + + Examples + ======== + + >>> from sympy.combinatorics import Cycle + >>> Cycle(1, 2)(2, 3) + (1 3 2) + + An instance of a Cycle will automatically parse list-like + objects and Permutations that are on the right. It is more + flexible than the Permutation in that all elements need not + be present: + + >>> a = Cycle(1, 2) + >>> a(2, 3) + (1 3 2) + >>> a(2, 3)(4, 5) + (1 3 2)(4 5) + + """ + rv = Cycle(*other) + for k, v in zip(list(self.keys()), [rv[self[k]] for k in self.keys()]): + rv[k] = v + return rv + + def list(self, size=None): + """Return the cycles as an explicit list starting from 0 up + to the greater of the largest value in the cycles and size. + + Truncation of trailing unmoved items will occur when size + is less than the maximum element in the cycle; if this is + desired, setting ``size=-1`` will guarantee such trimming. + + Examples + ======== + + >>> from sympy.combinatorics import Cycle + >>> p = Cycle(2, 3)(4, 5) + >>> p.list() + [0, 1, 3, 2, 5, 4] + >>> p.list(10) + [0, 1, 3, 2, 5, 4, 6, 7, 8, 9] + + Passing a length too small will trim trailing, unchanged elements + in the permutation: + + >>> Cycle(2, 4)(1, 2, 4).list(-1) + [0, 2, 1] + """ + if not self and size is None: + raise ValueError('must give size for empty Cycle') + if size is not None: + big = max([i for i in self.keys() if self[i] != i] + [0]) + size = max(size, big + 1) + else: + size = self.size + return [self[i] for i in range(size)] + + def __repr__(self): + """We want it to print as a Cycle, not as a dict. + + Examples + ======== + + >>> from sympy.combinatorics import Cycle + >>> Cycle(1, 2) + (1 2) + >>> print(_) + (1 2) + >>> list(Cycle(1, 2).items()) + [(1, 2), (2, 1)] + """ + if not self: + return 'Cycle()' + cycles = Permutation(self).cyclic_form + s = ''.join(str(tuple(c)) for c in cycles) + big = self.size - 1 + if not any(i == big for c in cycles for i in c): + s += '(%s)' % big + return 'Cycle%s' % s + + def __str__(self): + """We want it to be printed in a Cycle notation with no + comma in-between. + + Examples + ======== + + >>> from sympy.combinatorics import Cycle + >>> Cycle(1, 2) + (1 2) + >>> Cycle(1, 2, 4)(5, 6) + (1 2 4)(5 6) + """ + if not self: + return '()' + cycles = Permutation(self).cyclic_form + s = ''.join(str(tuple(c)) for c in cycles) + big = self.size - 1 + if not any(i == big for c in cycles for i in c): + s += '(%s)' % big + s = s.replace(',', '') + return s + + def __init__(self, *args): + """Load up a Cycle instance with the values for the cycle. + + Examples + ======== + + >>> from sympy.combinatorics import Cycle + >>> Cycle(1, 2, 6) + (1 2 6) + """ + + if not args: + return + if len(args) == 1: + if isinstance(args[0], Permutation): + for c in args[0].cyclic_form: + self.update(self(*c)) + return + elif isinstance(args[0], Cycle): + for k, v in args[0].items(): + self[k] = v + return + args = [as_int(a) for a in args] + if any(i < 0 for i in args): + raise ValueError('negative integers are not allowed in a cycle.') + if has_dups(args): + raise ValueError('All elements must be unique in a cycle.') + for i in range(-len(args), 0): + self[args[i]] = args[i + 1] + + @property + def size(self): + if not self: + return 0 + return max(self.keys()) + 1 + + def copy(self): + return Cycle(self) + + +class Permutation(Atom): + r""" + A permutation, alternatively known as an 'arrangement number' or 'ordering' + is an arrangement of the elements of an ordered list into a one-to-one + mapping with itself. The permutation of a given arrangement is given by + indicating the positions of the elements after re-arrangement [2]_. For + example, if one started with elements ``[x, y, a, b]`` (in that order) and + they were reordered as ``[x, y, b, a]`` then the permutation would be + ``[0, 1, 3, 2]``. Notice that (in SymPy) the first element is always referred + to as 0 and the permutation uses the indices of the elements in the + original ordering, not the elements ``(a, b, ...)`` themselves. + + >>> from sympy.combinatorics import Permutation + >>> from sympy import init_printing + >>> init_printing(perm_cyclic=False, pretty_print=False) + + Permutations Notation + ===================== + + Permutations are commonly represented in disjoint cycle or array forms. + + Array Notation and 2-line Form + ------------------------------------ + + In the 2-line form, the elements and their final positions are shown + as a matrix with 2 rows: + + [0 1 2 ... n-1] + [p(0) p(1) p(2) ... p(n-1)] + + Since the first line is always ``range(n)``, where n is the size of p, + it is sufficient to represent the permutation by the second line, + referred to as the "array form" of the permutation. This is entered + in brackets as the argument to the Permutation class: + + >>> p = Permutation([0, 2, 1]); p + Permutation([0, 2, 1]) + + Given i in range(p.size), the permutation maps i to i^p + + >>> [i^p for i in range(p.size)] + [0, 2, 1] + + The composite of two permutations p*q means first apply p, then q, so + i^(p*q) = (i^p)^q which is i^p^q according to Python precedence rules: + + >>> q = Permutation([2, 1, 0]) + >>> [i^p^q for i in range(3)] + [2, 0, 1] + >>> [i^(p*q) for i in range(3)] + [2, 0, 1] + + One can use also the notation p(i) = i^p, but then the composition + rule is (p*q)(i) = q(p(i)), not p(q(i)): + + >>> [(p*q)(i) for i in range(p.size)] + [2, 0, 1] + >>> [q(p(i)) for i in range(p.size)] + [2, 0, 1] + >>> [p(q(i)) for i in range(p.size)] + [1, 2, 0] + + Disjoint Cycle Notation + ----------------------- + + In disjoint cycle notation, only the elements that have shifted are + indicated. + + For example, [1, 3, 2, 0] can be represented as (0, 1, 3)(2). + This can be understood from the 2 line format of the given permutation. + In the 2-line form, + [0 1 2 3] + [1 3 2 0] + + The element in the 0th position is 1, so 0 -> 1. The element in the 1st + position is three, so 1 -> 3. And the element in the third position is again + 0, so 3 -> 0. Thus, 0 -> 1 -> 3 -> 0, and 2 -> 2. Thus, this can be represented + as 2 cycles: (0, 1, 3)(2). + In common notation, singular cycles are not explicitly written as they can be + inferred implicitly. + + Only the relative ordering of elements in a cycle matter: + + >>> Permutation(1,2,3) == Permutation(2,3,1) == Permutation(3,1,2) + True + + The disjoint cycle notation is convenient when representing + permutations that have several cycles in them: + + >>> Permutation(1, 2)(3, 5) == Permutation([[1, 2], [3, 5]]) + True + + It also provides some economy in entry when computing products of + permutations that are written in disjoint cycle notation: + + >>> Permutation(1, 2)(1, 3)(2, 3) + Permutation([0, 3, 2, 1]) + >>> _ == Permutation([[1, 2]])*Permutation([[1, 3]])*Permutation([[2, 3]]) + True + + Caution: when the cycles have common elements between them then the order + in which the permutations are applied matters. This module applies + the permutations from *left to right*. + + >>> Permutation(1, 2)(2, 3) == Permutation([(1, 2), (2, 3)]) + True + >>> Permutation(1, 2)(2, 3).list() + [0, 3, 1, 2] + + In the above case, (1,2) is computed before (2,3). + As 0 -> 0, 0 -> 0, element in position 0 is 0. + As 1 -> 2, 2 -> 3, element in position 1 is 3. + As 2 -> 1, 1 -> 1, element in position 2 is 1. + As 3 -> 3, 3 -> 2, element in position 3 is 2. + + If the first and second elements had been + swapped first, followed by the swapping of the second + and third, the result would have been [0, 2, 3, 1]. + If, you want to apply the cycles in the conventional + right to left order, call the function with arguments in reverse order + as demonstrated below: + + >>> Permutation([(1, 2), (2, 3)][::-1]).list() + [0, 2, 3, 1] + + Entering a singleton in a permutation is a way to indicate the size of the + permutation. The ``size`` keyword can also be used. + + Array-form entry: + + >>> Permutation([[1, 2], [9]]) + Permutation([0, 2, 1], size=10) + >>> Permutation([[1, 2]], size=10) + Permutation([0, 2, 1], size=10) + + Cyclic-form entry: + + >>> Permutation(1, 2, size=10) + Permutation([0, 2, 1], size=10) + >>> Permutation(9)(1, 2) + Permutation([0, 2, 1], size=10) + + Caution: no singleton containing an element larger than the largest + in any previous cycle can be entered. This is an important difference + in how Permutation and Cycle handle the ``__call__`` syntax. A singleton + argument at the start of a Permutation performs instantiation of the + Permutation and is permitted: + + >>> Permutation(5) + Permutation([], size=6) + + A singleton entered after instantiation is a call to the permutation + -- a function call -- and if the argument is out of range it will + trigger an error. For this reason, it is better to start the cycle + with the singleton: + + The following fails because there is no element 3: + + >>> Permutation(1, 2)(3) + Traceback (most recent call last): + ... + IndexError: list index out of range + + This is ok: only the call to an out of range singleton is prohibited; + otherwise the permutation autosizes: + + >>> Permutation(3)(1, 2) + Permutation([0, 2, 1, 3]) + >>> Permutation(1, 2)(3, 4) == Permutation(3, 4)(1, 2) + True + + + Equality testing + ---------------- + + The array forms must be the same in order for permutations to be equal: + + >>> Permutation([1, 0, 2, 3]) == Permutation([1, 0]) + False + + + Identity Permutation + -------------------- + + The identity permutation is a permutation in which no element is out of + place. It can be entered in a variety of ways. All the following create + an identity permutation of size 4: + + >>> I = Permutation([0, 1, 2, 3]) + >>> all(p == I for p in [ + ... Permutation(3), + ... Permutation(range(4)), + ... Permutation([], size=4), + ... Permutation(size=4)]) + True + + Watch out for entering the range *inside* a set of brackets (which is + cycle notation): + + >>> I == Permutation([range(4)]) + False + + + Permutation Printing + ==================== + + There are a few things to note about how Permutations are printed. + + .. deprecated:: 1.6 + + Configuring Permutation printing by setting + ``Permutation.print_cyclic`` is deprecated. Users should use the + ``perm_cyclic`` flag to the printers, as described below. + + 1) If you prefer one form (array or cycle) over another, you can set + ``init_printing`` with the ``perm_cyclic`` flag. + + >>> from sympy import init_printing + >>> p = Permutation(1, 2)(4, 5)(3, 4) + >>> p + Permutation([0, 2, 1, 4, 5, 3]) + + >>> init_printing(perm_cyclic=True, pretty_print=False) + >>> p + (1 2)(3 4 5) + + 2) Regardless of the setting, a list of elements in the array for cyclic + form can be obtained and either of those can be copied and supplied as + the argument to Permutation: + + >>> p.array_form + [0, 2, 1, 4, 5, 3] + >>> p.cyclic_form + [[1, 2], [3, 4, 5]] + >>> Permutation(_) == p + True + + 3) Printing is economical in that as little as possible is printed while + retaining all information about the size of the permutation: + + >>> init_printing(perm_cyclic=False, pretty_print=False) + >>> Permutation([1, 0, 2, 3]) + Permutation([1, 0, 2, 3]) + >>> Permutation([1, 0, 2, 3], size=20) + Permutation([1, 0], size=20) + >>> Permutation([1, 0, 2, 4, 3, 5, 6], size=20) + Permutation([1, 0, 2, 4, 3], size=20) + + >>> p = Permutation([1, 0, 2, 3]) + >>> init_printing(perm_cyclic=True, pretty_print=False) + >>> p + (3)(0 1) + >>> init_printing(perm_cyclic=False, pretty_print=False) + + The 2 was not printed but it is still there as can be seen with the + array_form and size methods: + + >>> p.array_form + [1, 0, 2, 3] + >>> p.size + 4 + + Short introduction to other methods + =================================== + + The permutation can act as a bijective function, telling what element is + located at a given position + + >>> q = Permutation([5, 2, 3, 4, 1, 0]) + >>> q.array_form[1] # the hard way + 2 + >>> q(1) # the easy way + 2 + >>> {i: q(i) for i in range(q.size)} # showing the bijection + {0: 5, 1: 2, 2: 3, 3: 4, 4: 1, 5: 0} + + The full cyclic form (including singletons) can be obtained: + + >>> p.full_cyclic_form + [[0, 1], [2], [3]] + + Any permutation can be factored into transpositions of pairs of elements: + + >>> Permutation([[1, 2], [3, 4, 5]]).transpositions() + [(1, 2), (3, 5), (3, 4)] + >>> Permutation.rmul(*[Permutation([ti], size=6) for ti in _]).cyclic_form + [[1, 2], [3, 4, 5]] + + The number of permutations on a set of n elements is given by n! and is + called the cardinality. + + >>> p.size + 4 + >>> p.cardinality + 24 + + A given permutation has a rank among all the possible permutations of the + same elements, but what that rank is depends on how the permutations are + enumerated. (There are a number of different methods of doing so.) The + lexicographic rank is given by the rank method and this rank is used to + increment a permutation with addition/subtraction: + + >>> p.rank() + 6 + >>> p + 1 + Permutation([1, 0, 3, 2]) + >>> p.next_lex() + Permutation([1, 0, 3, 2]) + >>> _.rank() + 7 + >>> p.unrank_lex(p.size, rank=7) + Permutation([1, 0, 3, 2]) + + The product of two permutations p and q is defined as their composition as + functions, (p*q)(i) = q(p(i)) [6]_. + + >>> p = Permutation([1, 0, 2, 3]) + >>> q = Permutation([2, 3, 1, 0]) + >>> list(q*p) + [2, 3, 0, 1] + >>> list(p*q) + [3, 2, 1, 0] + >>> [q(p(i)) for i in range(p.size)] + [3, 2, 1, 0] + + The permutation can be 'applied' to any list-like object, not only + Permutations: + + >>> p(['zero', 'one', 'four', 'two']) + ['one', 'zero', 'four', 'two'] + >>> p('zo42') + ['o', 'z', '4', '2'] + + If you have a list of arbitrary elements, the corresponding permutation + can be found with the from_sequence method: + + >>> Permutation.from_sequence('SymPy') + Permutation([1, 3, 2, 0, 4]) + + Checking if a Permutation is contained in a Group + ================================================= + + Generally if you have a group of permutations G on n symbols, and + you're checking if a permutation on less than n symbols is part + of that group, the check will fail. + + Here is an example for n=5 and we check if the cycle + (1,2,3) is in G: + + >>> from sympy import init_printing + >>> init_printing(perm_cyclic=True, pretty_print=False) + >>> from sympy.combinatorics import Cycle, Permutation + >>> from sympy.combinatorics.perm_groups import PermutationGroup + >>> G = PermutationGroup(Cycle(2, 3)(4, 5), Cycle(1, 2, 3, 4, 5)) + >>> p1 = Permutation(Cycle(2, 5, 3)) + >>> p2 = Permutation(Cycle(1, 2, 3)) + >>> a1 = Permutation(Cycle(1, 2, 3).list(6)) + >>> a2 = Permutation(Cycle(1, 2, 3)(5)) + >>> a3 = Permutation(Cycle(1, 2, 3),size=6) + >>> for p in [p1,p2,a1,a2,a3]: p, G.contains(p) + ((2 5 3), True) + ((1 2 3), False) + ((5)(1 2 3), True) + ((5)(1 2 3), True) + ((5)(1 2 3), True) + + The check for p2 above will fail. + + Checking if p1 is in G works because SymPy knows + G is a group on 5 symbols, and p1 is also on 5 symbols + (its largest element is 5). + + For ``a1``, the ``.list(6)`` call will extend the permutation to 5 + symbols, so the test will work as well. In the case of ``a2`` the + permutation is being extended to 5 symbols by using a singleton, + and in the case of ``a3`` it's extended through the constructor + argument ``size=6``. + + There is another way to do this, which is to tell the ``contains`` + method that the number of symbols the group is on does not need to + match perfectly the number of symbols for the permutation: + + >>> G.contains(p2,strict=False) + True + + This can be via the ``strict`` argument to the ``contains`` method, + and SymPy will try to extend the permutation on its own and then + perform the containment check. + + See Also + ======== + + Cycle + + References + ========== + + .. [1] Skiena, S. 'Permutations.' 1.1 in Implementing Discrete Mathematics + Combinatorics and Graph Theory with Mathematica. Reading, MA: + Addison-Wesley, pp. 3-16, 1990. + + .. [2] Knuth, D. E. The Art of Computer Programming, Vol. 4: Combinatorial + Algorithms, 1st ed. Reading, MA: Addison-Wesley, 2011. + + .. [3] Wendy Myrvold and Frank Ruskey. 2001. Ranking and unranking + permutations in linear time. Inf. Process. Lett. 79, 6 (September 2001), + 281-284. DOI=10.1016/S0020-0190(01)00141-7 + + .. [4] D. L. Kreher, D. R. Stinson 'Combinatorial Algorithms' + CRC Press, 1999 + + .. [5] Graham, R. L.; Knuth, D. E.; and Patashnik, O. + Concrete Mathematics: A Foundation for Computer Science, 2nd ed. + Reading, MA: Addison-Wesley, 1994. + + .. [6] https://en.wikipedia.org/w/index.php?oldid=499948155#Product_and_inverse + + .. [7] https://en.wikipedia.org/wiki/Lehmer_code + + """ + + is_Permutation = True + + _array_form = None + _cyclic_form = None + _cycle_structure = None + _size = None + _rank = None + + def __new__(cls, *args, size=None, **kwargs): + """ + Constructor for the Permutation object from a list or a + list of lists in which all elements of the permutation may + appear only once. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> from sympy import init_printing + >>> init_printing(perm_cyclic=False, pretty_print=False) + + Permutations entered in array-form are left unaltered: + + >>> Permutation([0, 2, 1]) + Permutation([0, 2, 1]) + + Permutations entered in cyclic form are converted to array form; + singletons need not be entered, but can be entered to indicate the + largest element: + + >>> Permutation([[4, 5, 6], [0, 1]]) + Permutation([1, 0, 2, 3, 5, 6, 4]) + >>> Permutation([[4, 5, 6], [0, 1], [19]]) + Permutation([1, 0, 2, 3, 5, 6, 4], size=20) + + All manipulation of permutations assumes that the smallest element + is 0 (in keeping with 0-based indexing in Python) so if the 0 is + missing when entering a permutation in array form, an error will be + raised: + + >>> Permutation([2, 1]) + Traceback (most recent call last): + ... + ValueError: Integers 0 through 2 must be present. + + If a permutation is entered in cyclic form, it can be entered without + singletons and the ``size`` specified so those values can be filled + in, otherwise the array form will only extend to the maximum value + in the cycles: + + >>> Permutation([[1, 4], [3, 5, 2]], size=10) + Permutation([0, 4, 3, 5, 1, 2], size=10) + >>> _.array_form + [0, 4, 3, 5, 1, 2, 6, 7, 8, 9] + """ + if size is not None: + size = int(size) + + #a) () + #b) (1) = identity + #c) (1, 2) = cycle + #d) ([1, 2, 3]) = array form + #e) ([[1, 2]]) = cyclic form + #f) (Cycle) = conversion to permutation + #g) (Permutation) = adjust size or return copy + ok = True + if not args: # a + return cls._af_new(list(range(size or 0))) + elif len(args) > 1: # c + return cls._af_new(Cycle(*args).list(size)) + if len(args) == 1: + a = args[0] + if isinstance(a, cls): # g + if size is None or size == a.size: + return a + return cls(a.array_form, size=size) + if isinstance(a, Cycle): # f + return cls._af_new(a.list(size)) + if not is_sequence(a): # b + if size is not None and a + 1 > size: + raise ValueError('size is too small when max is %s' % a) + return cls._af_new(list(range(a + 1))) + if has_variety(is_sequence(ai) for ai in a): + ok = False + else: + ok = False + if not ok: + raise ValueError("Permutation argument must be a list of ints, " + "a list of lists, Permutation or Cycle.") + + # safe to assume args are valid; this also makes a copy + # of the args + args = list(args[0]) + + is_cycle = args and is_sequence(args[0]) + if is_cycle: # e + args = [[int(i) for i in c] for c in args] + else: # d + args = [int(i) for i in args] + + # if there are n elements present, 0, 1, ..., n-1 should be present + # unless a cycle notation has been provided. A 0 will be added + # for convenience in case one wants to enter permutations where + # counting starts from 1. + + temp = flatten(args) + if has_dups(temp) and not is_cycle: + raise ValueError('there were repeated elements.') + temp = set(temp) + + if not is_cycle: + if temp != set(range(len(temp))): + raise ValueError('Integers 0 through %s must be present.' % + max(temp)) + if size is not None and temp and max(temp) + 1 > size: + raise ValueError('max element should not exceed %s' % (size - 1)) + + if is_cycle: + # it's not necessarily canonical so we won't store + # it -- use the array form instead + c = Cycle() + for ci in args: + c = c(*ci) + aform = c.list() + else: + aform = list(args) + if size and size > len(aform): + # don't allow for truncation of permutation which + # might split a cycle and lead to an invalid aform + # but do allow the permutation size to be increased + aform.extend(list(range(len(aform), size))) + + return cls._af_new(aform) + + @classmethod + def _af_new(cls, perm): + """A method to produce a Permutation object from a list; + the list is bound to the _array_form attribute, so it must + not be modified; this method is meant for internal use only; + the list ``a`` is supposed to be generated as a temporary value + in a method, so p = Perm._af_new(a) is the only object + to hold a reference to ``a``:: + + Examples + ======== + + >>> from sympy.combinatorics.permutations import Perm + >>> from sympy import init_printing + >>> init_printing(perm_cyclic=False, pretty_print=False) + >>> a = [2, 1, 3, 0] + >>> p = Perm._af_new(a) + >>> p + Permutation([2, 1, 3, 0]) + + """ + p = super().__new__(cls) + p._array_form = perm + p._size = len(perm) + return p + + def copy(self): + return self.__class__(self.array_form) + + def __getnewargs__(self): + return (self.array_form,) + + def _hashable_content(self): + # the array_form (a list) is the Permutation arg, so we need to + # return a tuple, instead + return tuple(self.array_form) + + @property + def array_form(self): + """ + Return a copy of the attribute _array_form + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([[2, 0], [3, 1]]) + >>> p.array_form + [2, 3, 0, 1] + >>> Permutation([[2, 0, 3, 1]]).array_form + [3, 2, 0, 1] + >>> Permutation([2, 0, 3, 1]).array_form + [2, 0, 3, 1] + >>> Permutation([[1, 2], [4, 5]]).array_form + [0, 2, 1, 3, 5, 4] + """ + return self._array_form[:] + + def list(self, size=None): + """Return the permutation as an explicit list, possibly + trimming unmoved elements if size is less than the maximum + element in the permutation; if this is desired, setting + ``size=-1`` will guarantee such trimming. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation(2, 3)(4, 5) + >>> p.list() + [0, 1, 3, 2, 5, 4] + >>> p.list(10) + [0, 1, 3, 2, 5, 4, 6, 7, 8, 9] + + Passing a length too small will trim trailing, unchanged elements + in the permutation: + + >>> Permutation(2, 4)(1, 2, 4).list(-1) + [0, 2, 1] + >>> Permutation(3).list(-1) + [] + """ + if not self and size is None: + raise ValueError('must give size for empty Cycle') + rv = self.array_form + if size is not None: + if size > self.size: + rv.extend(list(range(self.size, size))) + else: + # find first value from rhs where rv[i] != i + i = self.size - 1 + while rv: + if rv[-1] != i: + break + rv.pop() + i -= 1 + return rv + + @property + def cyclic_form(self): + """ + This is used to convert to the cyclic notation + from the canonical notation. Singletons are omitted. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([0, 3, 1, 2]) + >>> p.cyclic_form + [[1, 3, 2]] + >>> Permutation([1, 0, 2, 4, 3, 5]).cyclic_form + [[0, 1], [3, 4]] + + See Also + ======== + + array_form, full_cyclic_form + """ + if self._cyclic_form is not None: + return list(self._cyclic_form) + array_form = self.array_form + unchecked = [True] * len(array_form) + cyclic_form = [] + for i in range(len(array_form)): + if unchecked[i]: + cycle = [] + cycle.append(i) + unchecked[i] = False + j = i + while unchecked[array_form[j]]: + j = array_form[j] + cycle.append(j) + unchecked[j] = False + if len(cycle) > 1: + cyclic_form.append(cycle) + assert cycle == list(minlex(cycle)) + cyclic_form.sort() + self._cyclic_form = cyclic_form.copy() + return cyclic_form + + @property + def full_cyclic_form(self): + """Return permutation in cyclic form including singletons. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> Permutation([0, 2, 1]).full_cyclic_form + [[0], [1, 2]] + """ + need = set(range(self.size)) - set(flatten(self.cyclic_form)) + rv = self.cyclic_form + [[i] for i in need] + rv.sort() + return rv + + @property + def size(self): + """ + Returns the number of elements in the permutation. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> Permutation([[3, 2], [0, 1]]).size + 4 + + See Also + ======== + + cardinality, length, order, rank + """ + return self._size + + def support(self): + """Return the elements in permutation, P, for which P[i] != i. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([[3, 2], [0, 1], [4]]) + >>> p.array_form + [1, 0, 3, 2, 4] + >>> p.support() + [0, 1, 2, 3] + """ + a = self.array_form + return [i for i, e in enumerate(a) if e != i] + + def __add__(self, other): + """Return permutation that is other higher in rank than self. + + The rank is the lexicographical rank, with the identity permutation + having rank of 0. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> I = Permutation([0, 1, 2, 3]) + >>> a = Permutation([2, 1, 3, 0]) + >>> I + a.rank() == a + True + + See Also + ======== + + __sub__, inversion_vector + + """ + rank = (self.rank() + other) % self.cardinality + rv = self.unrank_lex(self.size, rank) + rv._rank = rank + return rv + + def __sub__(self, other): + """Return the permutation that is other lower in rank than self. + + See Also + ======== + + __add__ + """ + return self.__add__(-other) + + @staticmethod + def rmul(*args): + """ + Return product of Permutations [a, b, c, ...] as the Permutation whose + ith value is a(b(c(i))). + + a, b, c, ... can be Permutation objects or tuples. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + + >>> a, b = [1, 0, 2], [0, 2, 1] + >>> a = Permutation(a); b = Permutation(b) + >>> list(Permutation.rmul(a, b)) + [1, 2, 0] + >>> [a(b(i)) for i in range(3)] + [1, 2, 0] + + This handles the operands in reverse order compared to the ``*`` operator: + + >>> a = Permutation(a); b = Permutation(b) + >>> list(a*b) + [2, 0, 1] + >>> [b(a(i)) for i in range(3)] + [2, 0, 1] + + Notes + ===== + + All items in the sequence will be parsed by Permutation as + necessary as long as the first item is a Permutation: + + >>> Permutation.rmul(a, [0, 2, 1]) == Permutation.rmul(a, b) + True + + The reverse order of arguments will raise a TypeError. + + """ + rv = args[0] + for i in range(1, len(args)): + rv = args[i]*rv + return rv + + @classmethod + def rmul_with_af(cls, *args): + """ + same as rmul, but the elements of args are Permutation objects + which have _array_form + """ + a = [x._array_form for x in args] + rv = cls._af_new(_af_rmuln(*a)) + return rv + + def mul_inv(self, other): + """ + other*~self, self and other have _array_form + """ + a = _af_invert(self._array_form) + b = other._array_form + return self._af_new(_af_rmul(a, b)) + + def __rmul__(self, other): + """This is needed to coerce other to Permutation in rmul.""" + cls = type(self) + return cls(other)*self + + def __mul__(self, other): + """ + Return the product a*b as a Permutation; the ith value is b(a(i)). + + Examples + ======== + + >>> from sympy.combinatorics.permutations import _af_rmul, Permutation + + >>> a, b = [1, 0, 2], [0, 2, 1] + >>> a = Permutation(a); b = Permutation(b) + >>> list(a*b) + [2, 0, 1] + >>> [b(a(i)) for i in range(3)] + [2, 0, 1] + + This handles operands in reverse order compared to _af_rmul and rmul: + + >>> al = list(a); bl = list(b) + >>> _af_rmul(al, bl) + [1, 2, 0] + >>> [al[bl[i]] for i in range(3)] + [1, 2, 0] + + It is acceptable for the arrays to have different lengths; the shorter + one will be padded to match the longer one: + + >>> from sympy import init_printing + >>> init_printing(perm_cyclic=False, pretty_print=False) + >>> b*Permutation([1, 0]) + Permutation([1, 2, 0]) + >>> Permutation([1, 0])*b + Permutation([2, 0, 1]) + + It is also acceptable to allow coercion to handle conversion of a + single list to the left of a Permutation: + + >>> [0, 1]*a # no change: 2-element identity + Permutation([1, 0, 2]) + >>> [[0, 1]]*a # exchange first two elements + Permutation([0, 1, 2]) + + You cannot use more than 1 cycle notation in a product of cycles + since coercion can only handle one argument to the left. To handle + multiple cycles it is convenient to use Cycle instead of Permutation: + + >>> [[1, 2]]*[[2, 3]]*Permutation([]) # doctest: +SKIP + >>> from sympy.combinatorics.permutations import Cycle + >>> Cycle(1, 2)(2, 3) + (1 3 2) + + """ + from sympy.combinatorics.perm_groups import PermutationGroup, Coset + if isinstance(other, PermutationGroup): + return Coset(self, other, dir='-') + a = self.array_form + # __rmul__ makes sure the other is a Permutation + b = other.array_form + if not b: + perm = a + else: + b.extend(list(range(len(b), len(a)))) + perm = [b[i] for i in a] + b[len(a):] + return self._af_new(perm) + + def commutes_with(self, other): + """ + Checks if the elements are commuting. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> a = Permutation([1, 4, 3, 0, 2, 5]) + >>> b = Permutation([0, 1, 2, 3, 4, 5]) + >>> a.commutes_with(b) + True + >>> b = Permutation([2, 3, 5, 4, 1, 0]) + >>> a.commutes_with(b) + False + """ + a = self.array_form + b = other.array_form + return _af_commutes_with(a, b) + + def __pow__(self, n): + """ + Routine for finding powers of a permutation. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> from sympy import init_printing + >>> init_printing(perm_cyclic=False, pretty_print=False) + >>> p = Permutation([2, 0, 3, 1]) + >>> p.order() + 4 + >>> p**4 + Permutation([0, 1, 2, 3]) + """ + if isinstance(n, Permutation): + raise NotImplementedError( + 'p**p is not defined; do you mean p^p (conjugate)?') + n = int(n) + return self._af_new(_af_pow(self.array_form, n)) + + def __rxor__(self, i): + """Return self(i) when ``i`` is an int. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation(1, 2, 9) + >>> 2^p == p(2) == 9 + True + """ + if int_valued(i): + return self(i) + else: + raise NotImplementedError( + "i^p = p(i) when i is an integer, not %s." % i) + + def __xor__(self, h): + """Return the conjugate permutation ``~h*self*h` `. + + Explanation + =========== + + If ``a`` and ``b`` are conjugates, ``a = h*b*~h`` and + ``b = ~h*a*h`` and both have the same cycle structure. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation(1, 2, 9) + >>> q = Permutation(6, 9, 8) + >>> p*q != q*p + True + + Calculate and check properties of the conjugate: + + >>> c = p^q + >>> c == ~q*p*q and p == q*c*~q + True + + The expression q^p^r is equivalent to q^(p*r): + + >>> r = Permutation(9)(4, 6, 8) + >>> q^p^r == q^(p*r) + True + + If the term to the left of the conjugate operator, i, is an integer + then this is interpreted as selecting the ith element from the + permutation to the right: + + >>> all(i^p == p(i) for i in range(p.size)) + True + + Note that the * operator as higher precedence than the ^ operator: + + >>> q^r*p^r == q^(r*p)^r == Permutation(9)(1, 6, 4) + True + + Notes + ===== + + In Python the precedence rule is p^q^r = (p^q)^r which differs + in general from p^(q^r) + + >>> q^p^r + (9)(1 4 8) + >>> q^(p^r) + (9)(1 8 6) + + For a given r and p, both of the following are conjugates of p: + ~r*p*r and r*p*~r. But these are not necessarily the same: + + >>> ~r*p*r == r*p*~r + True + + >>> p = Permutation(1, 2, 9)(5, 6) + >>> ~r*p*r == r*p*~r + False + + The conjugate ~r*p*r was chosen so that ``p^q^r`` would be equivalent + to ``p^(q*r)`` rather than ``p^(r*q)``. To obtain r*p*~r, pass ~r to + this method: + + >>> p^~r == r*p*~r + True + """ + + if self.size != h.size: + raise ValueError("The permutations must be of equal size.") + a = [None]*self.size + h = h._array_form + p = self._array_form + for i in range(self.size): + a[h[i]] = h[p[i]] + return self._af_new(a) + + def transpositions(self): + """ + Return the permutation decomposed into a list of transpositions. + + Explanation + =========== + + It is always possible to express a permutation as the product of + transpositions, see [1] + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([[1, 2, 3], [0, 4, 5, 6, 7]]) + >>> t = p.transpositions() + >>> t + [(0, 7), (0, 6), (0, 5), (0, 4), (1, 3), (1, 2)] + >>> print(''.join(str(c) for c in t)) + (0, 7)(0, 6)(0, 5)(0, 4)(1, 3)(1, 2) + >>> Permutation.rmul(*[Permutation([ti], size=p.size) for ti in t]) == p + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Transposition_%28mathematics%29#Properties + + """ + a = self.cyclic_form + res = [] + for x in a: + nx = len(x) + if nx == 2: + res.append(tuple(x)) + elif nx > 2: + first = x[0] + res.extend((first, y) for y in x[nx - 1:0:-1]) + return res + + @classmethod + def from_sequence(self, i, key=None): + """Return the permutation needed to obtain ``i`` from the sorted + elements of ``i``. If custom sorting is desired, a key can be given. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + + >>> Permutation.from_sequence('SymPy') + (4)(0 1 3) + >>> _(sorted("SymPy")) + ['S', 'y', 'm', 'P', 'y'] + >>> Permutation.from_sequence('SymPy', key=lambda x: x.lower()) + (4)(0 2)(1 3) + """ + ic = list(zip(i, list(range(len(i))))) + if key: + ic.sort(key=lambda x: key(x[0])) + else: + ic.sort() + return ~Permutation([i[1] for i in ic]) + + def __invert__(self): + """ + Return the inverse of the permutation. + + A permutation multiplied by its inverse is the identity permutation. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> from sympy import init_printing + >>> init_printing(perm_cyclic=False, pretty_print=False) + >>> p = Permutation([[2, 0], [3, 1]]) + >>> ~p + Permutation([2, 3, 0, 1]) + >>> _ == p**-1 + True + >>> p*~p == ~p*p == Permutation([0, 1, 2, 3]) + True + """ + return self._af_new(_af_invert(self._array_form)) + + def __iter__(self): + """Yield elements from array form. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> list(Permutation(range(3))) + [0, 1, 2] + """ + yield from self.array_form + + def __repr__(self): + return srepr(self) + + def __call__(self, *i): + """ + Allows applying a permutation instance as a bijective function. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([[2, 0], [3, 1]]) + >>> p.array_form + [2, 3, 0, 1] + >>> [p(i) for i in range(4)] + [2, 3, 0, 1] + + If an array is given then the permutation selects the items + from the array (i.e. the permutation is applied to the array): + + >>> from sympy.abc import x + >>> p([x, 1, 0, x**2]) + [0, x**2, x, 1] + """ + # list indices can be Integer or int; leave this + # as it is (don't test or convert it) because this + # gets called a lot and should be fast + if len(i) == 1: + i = i[0] + if not isinstance(i, Iterable): + i = as_int(i) + if i < 0 or i > self.size: + raise TypeError( + "{} should be an integer between 0 and {}" + .format(i, self.size-1)) + return self._array_form[i] + # P([a, b, c]) + if len(i) != self.size: + raise TypeError( + "{} should have the length {}.".format(i, self.size)) + return [i[j] for j in self._array_form] + # P(1, 2, 3) + return self*Permutation(Cycle(*i), size=self.size) + + def atoms(self): + """ + Returns all the elements of a permutation + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> Permutation([0, 1, 2, 3, 4, 5]).atoms() + {0, 1, 2, 3, 4, 5} + >>> Permutation([[0, 1], [2, 3], [4, 5]]).atoms() + {0, 1, 2, 3, 4, 5} + """ + return set(self.array_form) + + def apply(self, i): + r"""Apply the permutation to an expression. + + Parameters + ========== + + i : Expr + It should be an integer between $0$ and $n-1$ where $n$ + is the size of the permutation. + + If it is a symbol or a symbolic expression that can + have integer values, an ``AppliedPermutation`` object + will be returned which can represent an unevaluated + function. + + Notes + ===== + + Any permutation can be defined as a bijective function + $\sigma : \{ 0, 1, \dots, n-1 \} \rightarrow \{ 0, 1, \dots, n-1 \}$ + where $n$ denotes the size of the permutation. + + The definition may even be extended for any set with distinctive + elements, such that the permutation can even be applied for + real numbers or such, however, it is not implemented for now for + computational reasons and the integrity with the group theory + module. + + This function is similar to the ``__call__`` magic, however, + ``__call__`` magic already has some other applications like + permuting an array or attaching new cycles, which would + not always be mathematically consistent. + + This also guarantees that the return type is a SymPy integer, + which guarantees the safety to use assumptions. + """ + i = _sympify(i) + if i.is_integer is False: + raise NotImplementedError("{} should be an integer.".format(i)) + + n = self.size + if (i < 0) == True or (i >= n) == True: + raise NotImplementedError( + "{} should be an integer between 0 and {}".format(i, n-1)) + + if i.is_Integer: + return Integer(self._array_form[i]) + return AppliedPermutation(self, i) + + def next_lex(self): + """ + Returns the next permutation in lexicographical order. + If self is the last permutation in lexicographical order + it returns None. + See [4] section 2.4. + + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([2, 3, 1, 0]) + >>> p = Permutation([2, 3, 1, 0]); p.rank() + 17 + >>> p = p.next_lex(); p.rank() + 18 + + See Also + ======== + + rank, unrank_lex + """ + perm = self.array_form[:] + n = len(perm) + i = n - 2 + while perm[i + 1] < perm[i]: + i -= 1 + if i == -1: + return None + else: + j = n - 1 + while perm[j] < perm[i]: + j -= 1 + perm[j], perm[i] = perm[i], perm[j] + i += 1 + j = n - 1 + while i < j: + perm[j], perm[i] = perm[i], perm[j] + i += 1 + j -= 1 + return self._af_new(perm) + + @classmethod + def unrank_nonlex(self, n, r): + """ + This is a linear time unranking algorithm that does not + respect lexicographic order [3]. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> from sympy import init_printing + >>> init_printing(perm_cyclic=False, pretty_print=False) + >>> Permutation.unrank_nonlex(4, 5) + Permutation([2, 0, 3, 1]) + >>> Permutation.unrank_nonlex(4, -1) + Permutation([0, 1, 2, 3]) + + See Also + ======== + + next_nonlex, rank_nonlex + """ + def _unrank1(n, r, a): + if n > 0: + a[n - 1], a[r % n] = a[r % n], a[n - 1] + _unrank1(n - 1, r//n, a) + + id_perm = list(range(n)) + n = int(n) + r = r % ifac(n) + _unrank1(n, r, id_perm) + return self._af_new(id_perm) + + def rank_nonlex(self, inv_perm=None): + """ + This is a linear time ranking algorithm that does not + enforce lexicographic order [3]. + + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([0, 1, 2, 3]) + >>> p.rank_nonlex() + 23 + + See Also + ======== + + next_nonlex, unrank_nonlex + """ + def _rank1(n, perm, inv_perm): + if n == 1: + return 0 + s = perm[n - 1] + t = inv_perm[n - 1] + perm[n - 1], perm[t] = perm[t], s + inv_perm[n - 1], inv_perm[s] = inv_perm[s], t + return s + n*_rank1(n - 1, perm, inv_perm) + + if inv_perm is None: + inv_perm = (~self).array_form + if not inv_perm: + return 0 + perm = self.array_form[:] + r = _rank1(len(perm), perm, inv_perm) + return r + + def next_nonlex(self): + """ + Returns the next permutation in nonlex order [3]. + If self is the last permutation in this order it returns None. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> from sympy import init_printing + >>> init_printing(perm_cyclic=False, pretty_print=False) + >>> p = Permutation([2, 0, 3, 1]); p.rank_nonlex() + 5 + >>> p = p.next_nonlex(); p + Permutation([3, 0, 1, 2]) + >>> p.rank_nonlex() + 6 + + See Also + ======== + + rank_nonlex, unrank_nonlex + """ + r = self.rank_nonlex() + if r == ifac(self.size) - 1: + return None + return self.unrank_nonlex(self.size, r + 1) + + def rank(self): + """ + Returns the lexicographic rank of the permutation. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([0, 1, 2, 3]) + >>> p.rank() + 0 + >>> p = Permutation([3, 2, 1, 0]) + >>> p.rank() + 23 + + See Also + ======== + + next_lex, unrank_lex, cardinality, length, order, size + """ + if self._rank is not None: + return self._rank + rank = 0 + rho = self.array_form[:] + n = self.size - 1 + size = n + 1 + psize = int(ifac(n)) + for j in range(size - 1): + rank += rho[j]*psize + for i in range(j + 1, size): + if rho[i] > rho[j]: + rho[i] -= 1 + psize //= n + n -= 1 + self._rank = rank + return rank + + @property + def cardinality(self): + """ + Returns the number of all possible permutations. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([0, 1, 2, 3]) + >>> p.cardinality + 24 + + See Also + ======== + + length, order, rank, size + """ + return int(ifac(self.size)) + + def parity(self): + """ + Computes the parity of a permutation. + + Explanation + =========== + + The parity of a permutation reflects the parity of the + number of inversions in the permutation, i.e., the + number of pairs of x and y such that ``x > y`` but ``p[x] < p[y]``. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([0, 1, 2, 3]) + >>> p.parity() + 0 + >>> p = Permutation([3, 2, 0, 1]) + >>> p.parity() + 1 + + See Also + ======== + + _af_parity + """ + if self._cyclic_form is not None: + return (self.size - self.cycles) % 2 + + return _af_parity(self.array_form) + + @property + def is_even(self): + """ + Checks if a permutation is even. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([0, 1, 2, 3]) + >>> p.is_even + True + >>> p = Permutation([3, 2, 1, 0]) + >>> p.is_even + True + + See Also + ======== + + is_odd + """ + return not self.is_odd + + @property + def is_odd(self): + """ + Checks if a permutation is odd. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([0, 1, 2, 3]) + >>> p.is_odd + False + >>> p = Permutation([3, 2, 0, 1]) + >>> p.is_odd + True + + See Also + ======== + + is_even + """ + return bool(self.parity() % 2) + + @property + def is_Singleton(self): + """ + Checks to see if the permutation contains only one number and is + thus the only possible permutation of this set of numbers + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> Permutation([0]).is_Singleton + True + >>> Permutation([0, 1]).is_Singleton + False + + See Also + ======== + + is_Empty + """ + return self.size == 1 + + @property + def is_Empty(self): + """ + Checks to see if the permutation is a set with zero elements + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> Permutation([]).is_Empty + True + >>> Permutation([0]).is_Empty + False + + See Also + ======== + + is_Singleton + """ + return self.size == 0 + + @property + def is_identity(self): + return self.is_Identity + + @property + def is_Identity(self): + """ + Returns True if the Permutation is an identity permutation. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([]) + >>> p.is_Identity + True + >>> p = Permutation([[0], [1], [2]]) + >>> p.is_Identity + True + >>> p = Permutation([0, 1, 2]) + >>> p.is_Identity + True + >>> p = Permutation([0, 2, 1]) + >>> p.is_Identity + False + + See Also + ======== + + order + """ + af = self.array_form + return not af or all(i == af[i] for i in range(self.size)) + + def ascents(self): + """ + Returns the positions of ascents in a permutation, ie, the location + where p[i] < p[i+1] + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([4, 0, 1, 3, 2]) + >>> p.ascents() + [1, 2] + + See Also + ======== + + descents, inversions, min, max + """ + a = self.array_form + pos = [i for i in range(len(a) - 1) if a[i] < a[i + 1]] + return pos + + def descents(self): + """ + Returns the positions of descents in a permutation, ie, the location + where p[i] > p[i+1] + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([4, 0, 1, 3, 2]) + >>> p.descents() + [0, 3] + + See Also + ======== + + ascents, inversions, min, max + """ + a = self.array_form + pos = [i for i in range(len(a) - 1) if a[i] > a[i + 1]] + return pos + + def max(self) -> int: + """ + The maximum element moved by the permutation. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([1, 0, 2, 3, 4]) + >>> p.max() + 1 + + See Also + ======== + + min, descents, ascents, inversions + """ + a = self.array_form + if not a: + return 0 + return max(_a for i, _a in enumerate(a) if _a != i) + + def min(self) -> int: + """ + The minimum element moved by the permutation. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([0, 1, 4, 3, 2]) + >>> p.min() + 2 + + See Also + ======== + + max, descents, ascents, inversions + """ + a = self.array_form + if not a: + return 0 + return min(_a for i, _a in enumerate(a) if _a != i) + + def inversions(self): + """ + Computes the number of inversions of a permutation. + + Explanation + =========== + + An inversion is where i > j but p[i] < p[j]. + + For small length of p, it iterates over all i and j + values and calculates the number of inversions. + For large length of p, it uses a variation of merge + sort to calculate the number of inversions. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([0, 1, 2, 3, 4, 5]) + >>> p.inversions() + 0 + >>> Permutation([3, 2, 1, 0]).inversions() + 6 + + See Also + ======== + + descents, ascents, min, max + + References + ========== + + .. [1] https://www.cp.eng.chula.ac.th/~prabhas//teaching/algo/algo2008/count-inv.htm + + """ + inversions = 0 + a = self.array_form + n = len(a) + if n < 130: + for i in range(n - 1): + b = a[i] + for c in a[i + 1:]: + if b > c: + inversions += 1 + else: + k = 1 + right = 0 + arr = a[:] + temp = a[:] + while k < n: + i = 0 + while i + k < n: + right = i + k * 2 - 1 + if right >= n: + right = n - 1 + inversions += _merge(arr, temp, i, i + k, right) + i = i + k * 2 + k = k * 2 + return inversions + + def commutator(self, x): + """Return the commutator of ``self`` and ``x``: ``~x*~self*x*self`` + + If f and g are part of a group, G, then the commutator of f and g + is the group identity iff f and g commute, i.e. fg == gf. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> from sympy import init_printing + >>> init_printing(perm_cyclic=False, pretty_print=False) + >>> p = Permutation([0, 2, 3, 1]) + >>> x = Permutation([2, 0, 3, 1]) + >>> c = p.commutator(x); c + Permutation([2, 1, 3, 0]) + >>> c == ~x*~p*x*p + True + + >>> I = Permutation(3) + >>> p = [I + i for i in range(6)] + >>> for i in range(len(p)): + ... for j in range(len(p)): + ... c = p[i].commutator(p[j]) + ... if p[i]*p[j] == p[j]*p[i]: + ... assert c == I + ... else: + ... assert c != I + ... + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Commutator + """ + + a = self.array_form + b = x.array_form + n = len(a) + if len(b) != n: + raise ValueError("The permutations must be of equal size.") + inva = [None]*n + for i in range(n): + inva[a[i]] = i + invb = [None]*n + for i in range(n): + invb[b[i]] = i + return self._af_new([a[b[inva[i]]] for i in invb]) + + def signature(self): + """ + Gives the signature of the permutation needed to place the + elements of the permutation in canonical order. + + The signature is calculated as (-1)^ + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([0, 1, 2]) + >>> p.inversions() + 0 + >>> p.signature() + 1 + >>> q = Permutation([0,2,1]) + >>> q.inversions() + 1 + >>> q.signature() + -1 + + See Also + ======== + + inversions + """ + if self.is_even: + return 1 + return -1 + + def order(self): + """ + Computes the order of a permutation. + + When the permutation is raised to the power of its + order it equals the identity permutation. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> from sympy import init_printing + >>> init_printing(perm_cyclic=False, pretty_print=False) + >>> p = Permutation([3, 1, 5, 2, 4, 0]) + >>> p.order() + 4 + >>> (p**(p.order())) + Permutation([], size=6) + + See Also + ======== + + identity, cardinality, length, rank, size + """ + + return reduce(lcm, [len(cycle) for cycle in self.cyclic_form], 1) + + def length(self): + """ + Returns the number of integers moved by a permutation. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> Permutation([0, 3, 2, 1]).length() + 2 + >>> Permutation([[0, 1], [2, 3]]).length() + 4 + + See Also + ======== + + min, max, support, cardinality, order, rank, size + """ + + return len(self.support()) + + @property + def cycle_structure(self): + """Return the cycle structure of the permutation as a dictionary + indicating the multiplicity of each cycle length. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> Permutation(3).cycle_structure + {1: 4} + >>> Permutation(0, 4, 3)(1, 2)(5, 6).cycle_structure + {2: 2, 3: 1} + """ + if self._cycle_structure: + rv = self._cycle_structure + else: + rv = defaultdict(int) + singletons = self.size + for c in self.cyclic_form: + rv[len(c)] += 1 + singletons -= len(c) + if singletons: + rv[1] = singletons + self._cycle_structure = rv + return dict(rv) # make a copy + + @property + def cycles(self): + """ + Returns the number of cycles contained in the permutation + (including singletons). + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> Permutation([0, 1, 2]).cycles + 3 + >>> Permutation([0, 1, 2]).full_cyclic_form + [[0], [1], [2]] + >>> Permutation(0, 1)(2, 3).cycles + 2 + + See Also + ======== + sympy.functions.combinatorial.numbers.stirling + """ + return len(self.full_cyclic_form) + + def index(self): + """ + Returns the index of a permutation. + + The index of a permutation is the sum of all subscripts j such + that p[j] is greater than p[j+1]. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([3, 0, 2, 1, 4]) + >>> p.index() + 2 + """ + a = self.array_form + + return sum(j for j in range(len(a) - 1) if a[j] > a[j + 1]) + + def runs(self): + """ + Returns the runs of a permutation. + + An ascending sequence in a permutation is called a run [5]. + + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([2, 5, 7, 3, 6, 0, 1, 4, 8]) + >>> p.runs() + [[2, 5, 7], [3, 6], [0, 1, 4, 8]] + >>> q = Permutation([1,3,2,0]) + >>> q.runs() + [[1, 3], [2], [0]] + """ + return runs(self.array_form) + + def inversion_vector(self): + """Return the inversion vector of the permutation. + + The inversion vector consists of elements whose value + indicates the number of elements in the permutation + that are lesser than it and lie on its right hand side. + + The inversion vector is the same as the Lehmer encoding of a + permutation. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([4, 8, 0, 7, 1, 5, 3, 6, 2]) + >>> p.inversion_vector() + [4, 7, 0, 5, 0, 2, 1, 1] + >>> p = Permutation([3, 2, 1, 0]) + >>> p.inversion_vector() + [3, 2, 1] + + The inversion vector increases lexicographically with the rank + of the permutation, the -ith element cycling through 0..i. + + >>> p = Permutation(2) + >>> while p: + ... print('%s %s %s' % (p, p.inversion_vector(), p.rank())) + ... p = p.next_lex() + (2) [0, 0] 0 + (1 2) [0, 1] 1 + (2)(0 1) [1, 0] 2 + (0 1 2) [1, 1] 3 + (0 2 1) [2, 0] 4 + (0 2) [2, 1] 5 + + See Also + ======== + + from_inversion_vector + """ + self_array_form = self.array_form + n = len(self_array_form) + inversion_vector = [0] * (n - 1) + + for i in range(n - 1): + val = 0 + for j in range(i + 1, n): + if self_array_form[j] < self_array_form[i]: + val += 1 + inversion_vector[i] = val + return inversion_vector + + def rank_trotterjohnson(self): + """ + Returns the Trotter Johnson rank, which we get from the minimal + change algorithm. See [4] section 2.4. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([0, 1, 2, 3]) + >>> p.rank_trotterjohnson() + 0 + >>> p = Permutation([0, 2, 1, 3]) + >>> p.rank_trotterjohnson() + 7 + + See Also + ======== + + unrank_trotterjohnson, next_trotterjohnson + """ + if self.array_form == [] or self.is_Identity: + return 0 + if self.array_form == [1, 0]: + return 1 + perm = self.array_form + n = self.size + rank = 0 + for j in range(1, n): + k = 1 + i = 0 + while perm[i] != j: + if perm[i] < j: + k += 1 + i += 1 + j1 = j + 1 + if rank % 2 == 0: + rank = j1*rank + j1 - k + else: + rank = j1*rank + k - 1 + return rank + + @classmethod + def unrank_trotterjohnson(cls, size, rank): + """ + Trotter Johnson permutation unranking. See [4] section 2.4. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> from sympy import init_printing + >>> init_printing(perm_cyclic=False, pretty_print=False) + >>> Permutation.unrank_trotterjohnson(5, 10) + Permutation([0, 3, 1, 2, 4]) + + See Also + ======== + + rank_trotterjohnson, next_trotterjohnson + """ + perm = [0]*size + r2 = 0 + n = ifac(size) + pj = 1 + for j in range(2, size + 1): + pj *= j + r1 = (rank * pj) // n + k = r1 - j*r2 + if r2 % 2 == 0: + for i in range(j - 1, j - k - 1, -1): + perm[i] = perm[i - 1] + perm[j - k - 1] = j - 1 + else: + for i in range(j - 1, k, -1): + perm[i] = perm[i - 1] + perm[k] = j - 1 + r2 = r1 + return cls._af_new(perm) + + def next_trotterjohnson(self): + """ + Returns the next permutation in Trotter-Johnson order. + If self is the last permutation it returns None. + See [4] section 2.4. If it is desired to generate all such + permutations, they can be generated in order more quickly + with the ``generate_bell`` function. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> from sympy import init_printing + >>> init_printing(perm_cyclic=False, pretty_print=False) + >>> p = Permutation([3, 0, 2, 1]) + >>> p.rank_trotterjohnson() + 4 + >>> p = p.next_trotterjohnson(); p + Permutation([0, 3, 2, 1]) + >>> p.rank_trotterjohnson() + 5 + + See Also + ======== + + rank_trotterjohnson, unrank_trotterjohnson, sympy.utilities.iterables.generate_bell + """ + pi = self.array_form[:] + n = len(pi) + st = 0 + rho = pi[:] + done = False + m = n-1 + while m > 0 and not done: + d = rho.index(m) + for i in range(d, m): + rho[i] = rho[i + 1] + par = _af_parity(rho[:m]) + if par == 1: + if d == m: + m -= 1 + else: + pi[st + d], pi[st + d + 1] = pi[st + d + 1], pi[st + d] + done = True + else: + if d == 0: + m -= 1 + st += 1 + else: + pi[st + d], pi[st + d - 1] = pi[st + d - 1], pi[st + d] + done = True + if m == 0: + return None + return self._af_new(pi) + + def get_precedence_matrix(self): + """ + Gets the precedence matrix. This is used for computing the + distance between two permutations. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> from sympy import init_printing + >>> init_printing(perm_cyclic=False, pretty_print=False) + >>> p = Permutation.josephus(3, 6, 1) + >>> p + Permutation([2, 5, 3, 1, 4, 0]) + >>> p.get_precedence_matrix() + Matrix([ + [0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 1, 0], + [1, 1, 0, 1, 1, 1], + [1, 1, 0, 0, 1, 0], + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 1, 1, 0]]) + + See Also + ======== + + get_precedence_distance, get_adjacency_matrix, get_adjacency_distance + """ + m = zeros(self.size) + perm = self.array_form + for i in range(m.rows): + for j in range(i + 1, m.cols): + m[perm[i], perm[j]] = 1 + return m + + def get_precedence_distance(self, other): + """ + Computes the precedence distance between two permutations. + + Explanation + =========== + + Suppose p and p' represent n jobs. The precedence metric + counts the number of times a job j is preceded by job i + in both p and p'. This metric is commutative. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([2, 0, 4, 3, 1]) + >>> q = Permutation([3, 1, 2, 4, 0]) + >>> p.get_precedence_distance(q) + 7 + >>> q.get_precedence_distance(p) + 7 + + See Also + ======== + + get_precedence_matrix, get_adjacency_matrix, get_adjacency_distance + """ + if self.size != other.size: + raise ValueError("The permutations must be of equal size.") + self_prec_mat = self.get_precedence_matrix() + other_prec_mat = other.get_precedence_matrix() + n_prec = 0 + for i in range(self.size): + for j in range(self.size): + if i == j: + continue + if self_prec_mat[i, j] * other_prec_mat[i, j] == 1: + n_prec += 1 + d = self.size * (self.size - 1)//2 - n_prec + return d + + def get_adjacency_matrix(self): + """ + Computes the adjacency matrix of a permutation. + + Explanation + =========== + + If job i is adjacent to job j in a permutation p + then we set m[i, j] = 1 where m is the adjacency + matrix of p. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation.josephus(3, 6, 1) + >>> p.get_adjacency_matrix() + Matrix([ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 1], + [0, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0]]) + >>> q = Permutation([0, 1, 2, 3]) + >>> q.get_adjacency_matrix() + Matrix([ + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + [0, 0, 0, 0]]) + + See Also + ======== + + get_precedence_matrix, get_precedence_distance, get_adjacency_distance + """ + m = zeros(self.size) + perm = self.array_form + for i in range(self.size - 1): + m[perm[i], perm[i + 1]] = 1 + return m + + def get_adjacency_distance(self, other): + """ + Computes the adjacency distance between two permutations. + + Explanation + =========== + + This metric counts the number of times a pair i,j of jobs is + adjacent in both p and p'. If n_adj is this quantity then + the adjacency distance is n - n_adj - 1 [1] + + [1] Reeves, Colin R. Landscapes, Operators and Heuristic search, Annals + of Operational Research, 86, pp 473-490. (1999) + + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([0, 3, 1, 2, 4]) + >>> q = Permutation.josephus(4, 5, 2) + >>> p.get_adjacency_distance(q) + 3 + >>> r = Permutation([0, 2, 1, 4, 3]) + >>> p.get_adjacency_distance(r) + 4 + + See Also + ======== + + get_precedence_matrix, get_precedence_distance, get_adjacency_matrix + """ + if self.size != other.size: + raise ValueError("The permutations must be of the same size.") + self_adj_mat = self.get_adjacency_matrix() + other_adj_mat = other.get_adjacency_matrix() + n_adj = 0 + for i in range(self.size): + for j in range(self.size): + if i == j: + continue + if self_adj_mat[i, j] * other_adj_mat[i, j] == 1: + n_adj += 1 + d = self.size - n_adj - 1 + return d + + def get_positional_distance(self, other): + """ + Computes the positional distance between two permutations. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> p = Permutation([0, 3, 1, 2, 4]) + >>> q = Permutation.josephus(4, 5, 2) + >>> r = Permutation([3, 1, 4, 0, 2]) + >>> p.get_positional_distance(q) + 12 + >>> p.get_positional_distance(r) + 12 + + See Also + ======== + + get_precedence_distance, get_adjacency_distance + """ + a = self.array_form + b = other.array_form + if len(a) != len(b): + raise ValueError("The permutations must be of the same size.") + return sum(abs(a[i] - b[i]) for i in range(len(a))) + + @classmethod + def josephus(cls, m, n, s=1): + """Return as a permutation the shuffling of range(n) using the Josephus + scheme in which every m-th item is selected until all have been chosen. + The returned permutation has elements listed by the order in which they + were selected. + + The parameter ``s`` stops the selection process when there are ``s`` + items remaining and these are selected by continuing the selection, + counting by 1 rather than by ``m``. + + Consider selecting every 3rd item from 6 until only 2 remain:: + + choices chosen + ======== ====== + 012345 + 01 345 2 + 01 34 25 + 01 4 253 + 0 4 2531 + 0 25314 + 253140 + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> Permutation.josephus(3, 6, 2).array_form + [2, 5, 3, 1, 4, 0] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Flavius_Josephus + .. [2] https://en.wikipedia.org/wiki/Josephus_problem + .. [3] https://web.archive.org/web/20171008094331/http://www.wou.edu/~burtonl/josephus.html + + """ + from collections import deque + m -= 1 + Q = deque(list(range(n))) + perm = [] + while len(Q) > max(s, 1): + for dp in range(m): + Q.append(Q.popleft()) + perm.append(Q.popleft()) + perm.extend(list(Q)) + return cls(perm) + + @classmethod + def from_inversion_vector(cls, inversion): + """ + Calculates the permutation from the inversion vector. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> from sympy import init_printing + >>> init_printing(perm_cyclic=False, pretty_print=False) + >>> Permutation.from_inversion_vector([3, 2, 1, 0, 0]) + Permutation([3, 2, 1, 0, 4, 5]) + + """ + size = len(inversion) + N = list(range(size + 1)) + perm = [] + try: + for k in range(size): + val = N[inversion[k]] + perm.append(val) + N.remove(val) + except IndexError: + raise ValueError("The inversion vector is not valid.") + perm.extend(N) + return cls._af_new(perm) + + @classmethod + def random(cls, n): + """ + Generates a random permutation of length ``n``. + + Uses the underlying Python pseudo-random number generator. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> Permutation.random(2) in (Permutation([1, 0]), Permutation([0, 1])) + True + + """ + perm_array = list(range(n)) + random.shuffle(perm_array) + return cls._af_new(perm_array) + + @classmethod + def unrank_lex(cls, size, rank): + """ + Lexicographic permutation unranking. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> from sympy import init_printing + >>> init_printing(perm_cyclic=False, pretty_print=False) + >>> a = Permutation.unrank_lex(5, 10) + >>> a.rank() + 10 + >>> a + Permutation([0, 2, 4, 1, 3]) + + See Also + ======== + + rank, next_lex + """ + perm_array = [0] * size + psize = 1 + for i in range(size): + new_psize = psize*(i + 1) + d = (rank % new_psize) // psize + rank -= d*psize + perm_array[size - i - 1] = d + for j in range(size - i, size): + if perm_array[j] > d - 1: + perm_array[j] += 1 + psize = new_psize + return cls._af_new(perm_array) + + def resize(self, n): + """Resize the permutation to the new size ``n``. + + Parameters + ========== + + n : int + The new size of the permutation. + + Raises + ====== + + ValueError + If the permutation cannot be resized to the given size. + This may only happen when resized to a smaller size than + the original. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + + Increasing the size of a permutation: + + >>> p = Permutation(0, 1, 2) + >>> p = p.resize(5) + >>> p + (4)(0 1 2) + + Decreasing the size of the permutation: + + >>> p = p.resize(4) + >>> p + (3)(0 1 2) + + If resizing to the specific size breaks the cycles: + + >>> p.resize(2) + Traceback (most recent call last): + ... + ValueError: The permutation cannot be resized to 2 because the + cycle (0, 1, 2) may break. + """ + aform = self.array_form + l = len(aform) + if n > l: + aform += list(range(l, n)) + return Permutation._af_new(aform) + + elif n < l: + cyclic_form = self.full_cyclic_form + new_cyclic_form = [] + for cycle in cyclic_form: + cycle_min = min(cycle) + cycle_max = max(cycle) + if cycle_min <= n-1: + if cycle_max > n-1: + raise ValueError( + "The permutation cannot be resized to {} " + "because the cycle {} may break." + .format(n, tuple(cycle))) + + new_cyclic_form.append(cycle) + return Permutation(new_cyclic_form) + + return self + + # XXX Deprecated flag + print_cyclic = None + + +def _merge(arr, temp, left, mid, right): + """ + Merges two sorted arrays and calculates the inversion count. + + Helper function for calculating inversions. This method is + for internal use only. + """ + i = k = left + j = mid + inv_count = 0 + while i < mid and j <= right: + if arr[i] < arr[j]: + temp[k] = arr[i] + k += 1 + i += 1 + else: + temp[k] = arr[j] + k += 1 + j += 1 + inv_count += (mid -i) + while i < mid: + temp[k] = arr[i] + k += 1 + i += 1 + if j <= right: + k += right - j + 1 + j += right - j + 1 + arr[left:k + 1] = temp[left:k + 1] + else: + arr[left:right + 1] = temp[left:right + 1] + return inv_count + +Perm = Permutation +_af_new = Perm._af_new + + +class AppliedPermutation(Expr): + """A permutation applied to a symbolic variable. + + Parameters + ========== + + perm : Permutation + x : Expr + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.combinatorics import Permutation + + Creating a symbolic permutation function application: + + >>> x = Symbol('x') + >>> p = Permutation(0, 1, 2) + >>> p.apply(x) + AppliedPermutation((0 1 2), x) + >>> _.subs(x, 1) + 2 + """ + def __new__(cls, perm, x, evaluate=None): + if evaluate is None: + evaluate = global_parameters.evaluate + + perm = _sympify(perm) + x = _sympify(x) + + if not isinstance(perm, Permutation): + raise ValueError("{} must be a Permutation instance." + .format(perm)) + + if evaluate: + if x.is_Integer: + return perm.apply(x) + + obj = super().__new__(cls, perm, x) + return obj + + +@dispatch(Permutation, Permutation) +def _eval_is_eq(lhs, rhs): + if lhs._size != rhs._size: + return None + return lhs._array_form == rhs._array_form diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/polyhedron.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/polyhedron.py new file mode 100644 index 0000000000000000000000000000000000000000..2bc05d7d97c840649661f3290442499c841ca7c1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/polyhedron.py @@ -0,0 +1,1019 @@ +from sympy.combinatorics import Permutation as Perm +from sympy.combinatorics.perm_groups import PermutationGroup +from sympy.core import Basic, Tuple, default_sort_key +from sympy.sets import FiniteSet +from sympy.utilities.iterables import (minlex, unflatten, flatten) +from sympy.utilities.misc import as_int + +rmul = Perm.rmul + + +class Polyhedron(Basic): + """ + Represents the polyhedral symmetry group (PSG). + + Explanation + =========== + + The PSG is one of the symmetry groups of the Platonic solids. + There are three polyhedral groups: the tetrahedral group + of order 12, the octahedral group of order 24, and the + icosahedral group of order 60. + + All doctests have been given in the docstring of the + constructor of the object. + + References + ========== + + .. [1] https://mathworld.wolfram.com/PolyhedralGroup.html + + """ + _edges = None + + def __new__(cls, corners, faces=(), pgroup=()): + """ + The constructor of the Polyhedron group object. + + Explanation + =========== + + It takes up to three parameters: the corners, faces, and + allowed transformations. + + The corners/vertices are entered as a list of arbitrary + expressions that are used to identify each vertex. + + The faces are entered as a list of tuples of indices; a tuple + of indices identifies the vertices which define the face. They + should be entered in a cw or ccw order; they will be standardized + by reversal and rotation to be give the lowest lexical ordering. + If no faces are given then no edges will be computed. + + >>> from sympy.combinatorics.polyhedron import Polyhedron + >>> Polyhedron(list('abc'), [(1, 2, 0)]).faces + {(0, 1, 2)} + >>> Polyhedron(list('abc'), [(1, 0, 2)]).faces + {(0, 1, 2)} + + The allowed transformations are entered as allowable permutations + of the vertices for the polyhedron. Instance of Permutations + (as with faces) should refer to the supplied vertices by index. + These permutation are stored as a PermutationGroup. + + Examples + ======== + + >>> from sympy.combinatorics.permutations import Permutation + >>> from sympy import init_printing + >>> from sympy.abc import w, x, y, z + >>> init_printing(pretty_print=False, perm_cyclic=False) + + Here we construct the Polyhedron object for a tetrahedron. + + >>> corners = [w, x, y, z] + >>> faces = [(0, 1, 2), (0, 2, 3), (0, 3, 1), (1, 2, 3)] + + Next, allowed transformations of the polyhedron must be given. This + is given as permutations of vertices. + + Although the vertices of a tetrahedron can be numbered in 24 (4!) + different ways, there are only 12 different orientations for a + physical tetrahedron. The following permutations, applied once or + twice, will generate all 12 of the orientations. (The identity + permutation, Permutation(range(4)), is not included since it does + not change the orientation of the vertices.) + + >>> pgroup = [Permutation([[0, 1, 2], [3]]), \ + Permutation([[0, 1, 3], [2]]), \ + Permutation([[0, 2, 3], [1]]), \ + Permutation([[1, 2, 3], [0]]), \ + Permutation([[0, 1], [2, 3]]), \ + Permutation([[0, 2], [1, 3]]), \ + Permutation([[0, 3], [1, 2]])] + + The Polyhedron is now constructed and demonstrated: + + >>> tetra = Polyhedron(corners, faces, pgroup) + >>> tetra.size + 4 + >>> tetra.edges + {(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)} + >>> tetra.corners + (w, x, y, z) + + It can be rotated with an arbitrary permutation of vertices, e.g. + the following permutation is not in the pgroup: + + >>> tetra.rotate(Permutation([0, 1, 3, 2])) + >>> tetra.corners + (w, x, z, y) + + An allowed permutation of the vertices can be constructed by + repeatedly applying permutations from the pgroup to the vertices. + Here is a demonstration that applying p and p**2 for every p in + pgroup generates all the orientations of a tetrahedron and no others: + + >>> all = ( (w, x, y, z), \ + (x, y, w, z), \ + (y, w, x, z), \ + (w, z, x, y), \ + (z, w, y, x), \ + (w, y, z, x), \ + (y, z, w, x), \ + (x, z, y, w), \ + (z, y, x, w), \ + (y, x, z, w), \ + (x, w, z, y), \ + (z, x, w, y) ) + + >>> got = [] + >>> for p in (pgroup + [p**2 for p in pgroup]): + ... h = Polyhedron(corners) + ... h.rotate(p) + ... got.append(h.corners) + ... + >>> set(got) == set(all) + True + + The make_perm method of a PermutationGroup will randomly pick + permutations, multiply them together, and return the permutation that + can be applied to the polyhedron to give the orientation produced + by those individual permutations. + + Here, 3 permutations are used: + + >>> tetra.pgroup.make_perm(3) # doctest: +SKIP + Permutation([0, 3, 1, 2]) + + To select the permutations that should be used, supply a list + of indices to the permutations in pgroup in the order they should + be applied: + + >>> use = [0, 0, 2] + >>> p002 = tetra.pgroup.make_perm(3, use) + >>> p002 + Permutation([1, 0, 3, 2]) + + + Apply them one at a time: + + >>> tetra.reset() + >>> for i in use: + ... tetra.rotate(pgroup[i]) + ... + >>> tetra.vertices + (x, w, z, y) + >>> sequentially = tetra.vertices + + Apply the composite permutation: + + >>> tetra.reset() + >>> tetra.rotate(p002) + >>> tetra.corners + (x, w, z, y) + >>> tetra.corners in all and tetra.corners == sequentially + True + + Notes + ===== + + Defining permutation groups + --------------------------- + + It is not necessary to enter any permutations, nor is necessary to + enter a complete set of transformations. In fact, for a polyhedron, + all configurations can be constructed from just two permutations. + For example, the orientations of a tetrahedron can be generated from + an axis passing through a vertex and face and another axis passing + through a different vertex or from an axis passing through the + midpoints of two edges opposite of each other. + + For simplicity of presentation, consider a square -- + not a cube -- with vertices 1, 2, 3, and 4: + + 1-----2 We could think of axes of rotation being: + | | 1) through the face + | | 2) from midpoint 1-2 to 3-4 or 1-3 to 2-4 + 3-----4 3) lines 1-4 or 2-3 + + + To determine how to write the permutations, imagine 4 cameras, + one at each corner, labeled A-D: + + A B A B + 1-----2 1-----3 vertex index: + | | | | 1 0 + | | | | 2 1 + 3-----4 2-----4 3 2 + C D C D 4 3 + + original after rotation + along 1-4 + + A diagonal and a face axis will be chosen for the "permutation group" + from which any orientation can be constructed. + + >>> pgroup = [] + + Imagine a clockwise rotation when viewing 1-4 from camera A. The new + orientation is (in camera-order): 1, 3, 2, 4 so the permutation is + given using the *indices* of the vertices as: + + >>> pgroup.append(Permutation((0, 2, 1, 3))) + + Now imagine rotating clockwise when looking down an axis entering the + center of the square as viewed. The new camera-order would be + 3, 1, 4, 2 so the permutation is (using indices): + + >>> pgroup.append(Permutation((2, 0, 3, 1))) + + The square can now be constructed: + ** use real-world labels for the vertices, entering them in + camera order + ** for the faces we use zero-based indices of the vertices + in *edge-order* as the face is traversed; neither the + direction nor the starting point matter -- the faces are + only used to define edges (if so desired). + + >>> square = Polyhedron((1, 2, 3, 4), [(0, 1, 3, 2)], pgroup) + + To rotate the square with a single permutation we can do: + + >>> square.rotate(square.pgroup[0]) + >>> square.corners + (1, 3, 2, 4) + + To use more than one permutation (or to use one permutation more + than once) it is more convenient to use the make_perm method: + + >>> p011 = square.pgroup.make_perm([0, 1, 1]) # diag flip + 2 rotations + >>> square.reset() # return to initial orientation + >>> square.rotate(p011) + >>> square.corners + (4, 2, 3, 1) + + Thinking outside the box + ------------------------ + + Although the Polyhedron object has a direct physical meaning, it + actually has broader application. In the most general sense it is + just a decorated PermutationGroup, allowing one to connect the + permutations to something physical. For example, a Rubik's cube is + not a proper polyhedron, but the Polyhedron class can be used to + represent it in a way that helps to visualize the Rubik's cube. + + >>> from sympy import flatten, unflatten, symbols + >>> from sympy.combinatorics import RubikGroup + >>> facelets = flatten([symbols(s+'1:5') for s in 'UFRBLD']) + >>> def show(): + ... pairs = unflatten(r2.corners, 2) + ... print(pairs[::2]) + ... print(pairs[1::2]) + ... + >>> r2 = Polyhedron(facelets, pgroup=RubikGroup(2)) + >>> show() + [(U1, U2), (F1, F2), (R1, R2), (B1, B2), (L1, L2), (D1, D2)] + [(U3, U4), (F3, F4), (R3, R4), (B3, B4), (L3, L4), (D3, D4)] + >>> r2.rotate(0) # cw rotation of F + >>> show() + [(U1, U2), (F3, F1), (U3, R2), (B1, B2), (L1, D1), (R3, R1)] + [(L4, L2), (F4, F2), (U4, R4), (B3, B4), (L3, D2), (D3, D4)] + + Predefined Polyhedra + ==================== + + For convenience, the vertices and faces are defined for the following + standard solids along with a permutation group for transformations. + When the polyhedron is oriented as indicated below, the vertices in + a given horizontal plane are numbered in ccw direction, starting from + the vertex that will give the lowest indices in a given face. (In the + net of the vertices, indices preceded by "-" indicate replication of + the lhs index in the net.) + + tetrahedron, tetrahedron_faces + ------------------------------ + + 4 vertices (vertex up) net: + + 0 0-0 + 1 2 3-1 + + 4 faces: + + (0, 1, 2) (0, 2, 3) (0, 3, 1) (1, 2, 3) + + cube, cube_faces + ---------------- + + 8 vertices (face up) net: + + 0 1 2 3-0 + 4 5 6 7-4 + + 6 faces: + + (0, 1, 2, 3) + (0, 1, 5, 4) (1, 2, 6, 5) (2, 3, 7, 6) (0, 3, 7, 4) + (4, 5, 6, 7) + + octahedron, octahedron_faces + ---------------------------- + + 6 vertices (vertex up) net: + + 0 0 0-0 + 1 2 3 4-1 + 5 5 5-5 + + 8 faces: + + (0, 1, 2) (0, 2, 3) (0, 3, 4) (0, 1, 4) + (1, 2, 5) (2, 3, 5) (3, 4, 5) (1, 4, 5) + + dodecahedron, dodecahedron_faces + -------------------------------- + + 20 vertices (vertex up) net: + + 0 1 2 3 4 -0 + 5 6 7 8 9 -5 + 14 10 11 12 13-14 + 15 16 17 18 19-15 + + 12 faces: + + (0, 1, 2, 3, 4) (0, 1, 6, 10, 5) (1, 2, 7, 11, 6) + (2, 3, 8, 12, 7) (3, 4, 9, 13, 8) (0, 4, 9, 14, 5) + (5, 10, 16, 15, 14) (6, 10, 16, 17, 11) (7, 11, 17, 18, 12) + (8, 12, 18, 19, 13) (9, 13, 19, 15, 14)(15, 16, 17, 18, 19) + + icosahedron, icosahedron_faces + ------------------------------ + + 12 vertices (face up) net: + + 0 0 0 0 -0 + 1 2 3 4 5 -1 + 6 7 8 9 10 -6 + 11 11 11 11 -11 + + 20 faces: + + (0, 1, 2) (0, 2, 3) (0, 3, 4) + (0, 4, 5) (0, 1, 5) (1, 2, 6) + (2, 3, 7) (3, 4, 8) (4, 5, 9) + (1, 5, 10) (2, 6, 7) (3, 7, 8) + (4, 8, 9) (5, 9, 10) (1, 6, 10) + (6, 7, 11) (7, 8, 11) (8, 9, 11) + (9, 10, 11) (6, 10, 11) + + >>> from sympy.combinatorics.polyhedron import cube + >>> cube.edges + {(0, 1), (0, 3), (0, 4), (1, 2), (1, 5), (2, 3), (2, 6), (3, 7), (4, 5), (4, 7), (5, 6), (6, 7)} + + If you want to use letters or other names for the corners you + can still use the pre-calculated faces: + + >>> corners = list('abcdefgh') + >>> Polyhedron(corners, cube.faces).corners + (a, b, c, d, e, f, g, h) + + References + ========== + + .. [1] www.ocf.berkeley.edu/~wwu/articles/platonicsolids.pdf + + """ + faces = [minlex(f, directed=False, key=default_sort_key) for f in faces] + corners, faces, pgroup = args = \ + [Tuple(*a) for a in (corners, faces, pgroup)] + obj = Basic.__new__(cls, *args) + obj._corners = tuple(corners) # in order given + obj._faces = FiniteSet(*faces) + if pgroup and pgroup[0].size != len(corners): + raise ValueError("Permutation size unequal to number of corners.") + # use the identity permutation if none are given + obj._pgroup = PermutationGroup( + pgroup or [Perm(range(len(corners)))] ) + return obj + + @property + def corners(self): + """ + Get the corners of the Polyhedron. + + The method ``vertices`` is an alias for ``corners``. + + Examples + ======== + + >>> from sympy.combinatorics import Polyhedron + >>> from sympy.abc import a, b, c, d + >>> p = Polyhedron(list('abcd')) + >>> p.corners == p.vertices == (a, b, c, d) + True + + See Also + ======== + + array_form, cyclic_form + """ + return self._corners + vertices = corners + + @property + def array_form(self): + """Return the indices of the corners. + + The indices are given relative to the original position of corners. + + Examples + ======== + + >>> from sympy.combinatorics.polyhedron import tetrahedron + >>> tetrahedron = tetrahedron.copy() + >>> tetrahedron.array_form + [0, 1, 2, 3] + + >>> tetrahedron.rotate(0) + >>> tetrahedron.array_form + [0, 2, 3, 1] + >>> tetrahedron.pgroup[0].array_form + [0, 2, 3, 1] + + See Also + ======== + + corners, cyclic_form + """ + corners = list(self.args[0]) + return [corners.index(c) for c in self.corners] + + @property + def cyclic_form(self): + """Return the indices of the corners in cyclic notation. + + The indices are given relative to the original position of corners. + + See Also + ======== + + corners, array_form + """ + return Perm._af_new(self.array_form).cyclic_form + + @property + def size(self): + """ + Get the number of corners of the Polyhedron. + """ + return len(self._corners) + + @property + def faces(self): + """ + Get the faces of the Polyhedron. + """ + return self._faces + + @property + def pgroup(self): + """ + Get the permutations of the Polyhedron. + """ + return self._pgroup + + @property + def edges(self): + """ + Given the faces of the polyhedra we can get the edges. + + Examples + ======== + + >>> from sympy.combinatorics import Polyhedron + >>> from sympy.abc import a, b, c + >>> corners = (a, b, c) + >>> faces = [(0, 1, 2)] + >>> Polyhedron(corners, faces).edges + {(0, 1), (0, 2), (1, 2)} + + """ + if self._edges is None: + output = set() + for face in self.faces: + for i in range(len(face)): + edge = tuple(sorted([face[i], face[i - 1]])) + output.add(edge) + self._edges = FiniteSet(*output) + return self._edges + + def rotate(self, perm): + """ + Apply a permutation to the polyhedron *in place*. The permutation + may be given as a Permutation instance or an integer indicating + which permutation from pgroup of the Polyhedron should be + applied. + + This is an operation that is analogous to rotation about + an axis by a fixed increment. + + Notes + ===== + + When a Permutation is applied, no check is done to see if that + is a valid permutation for the Polyhedron. For example, a cube + could be given a permutation which effectively swaps only 2 + vertices. A valid permutation (that rotates the object in a + physical way) will be obtained if one only uses + permutations from the ``pgroup`` of the Polyhedron. On the other + hand, allowing arbitrary rotations (applications of permutations) + gives a way to follow named elements rather than indices since + Polyhedron allows vertices to be named while Permutation works + only with indices. + + Examples + ======== + + >>> from sympy.combinatorics import Polyhedron, Permutation + >>> from sympy.combinatorics.polyhedron import cube + >>> cube = cube.copy() + >>> cube.corners + (0, 1, 2, 3, 4, 5, 6, 7) + >>> cube.rotate(0) + >>> cube.corners + (1, 2, 3, 0, 5, 6, 7, 4) + + A non-physical "rotation" that is not prohibited by this method: + + >>> cube.reset() + >>> cube.rotate(Permutation([[1, 2]], size=8)) + >>> cube.corners + (0, 2, 1, 3, 4, 5, 6, 7) + + Polyhedron can be used to follow elements of set that are + identified by letters instead of integers: + + >>> shadow = h5 = Polyhedron(list('abcde')) + >>> p = Permutation([3, 0, 1, 2, 4]) + >>> h5.rotate(p) + >>> h5.corners + (d, a, b, c, e) + >>> _ == shadow.corners + True + >>> copy = h5.copy() + >>> h5.rotate(p) + >>> h5.corners == copy.corners + False + """ + if not isinstance(perm, Perm): + perm = self.pgroup[perm] + # and we know it's valid + else: + if perm.size != self.size: + raise ValueError('Polyhedron and Permutation sizes differ.') + a = perm.array_form + corners = [self.corners[a[i]] for i in range(len(self.corners))] + self._corners = tuple(corners) + + def reset(self): + """Return corners to their original positions. + + Examples + ======== + + >>> from sympy.combinatorics.polyhedron import tetrahedron as T + >>> T = T.copy() + >>> T.corners + (0, 1, 2, 3) + >>> T.rotate(0) + >>> T.corners + (0, 2, 3, 1) + >>> T.reset() + >>> T.corners + (0, 1, 2, 3) + """ + self._corners = self.args[0] + + +def _pgroup_calcs(): + """Return the permutation groups for each of the polyhedra and the face + definitions: tetrahedron, cube, octahedron, dodecahedron, icosahedron, + tetrahedron_faces, cube_faces, octahedron_faces, dodecahedron_faces, + icosahedron_faces + + Explanation + =========== + + (This author did not find and did not know of a better way to do it though + there likely is such a way.) + + Although only 2 permutations are needed for a polyhedron in order to + generate all the possible orientations, a group of permutations is + provided instead. A set of permutations is called a "group" if:: + + a*b = c (for any pair of permutations in the group, a and b, their + product, c, is in the group) + + a*(b*c) = (a*b)*c (for any 3 permutations in the group associativity holds) + + there is an identity permutation, I, such that I*a = a*I for all elements + in the group + + a*b = I (the inverse of each permutation is also in the group) + + None of the polyhedron groups defined follow these definitions of a group. + Instead, they are selected to contain those permutations whose powers + alone will construct all orientations of the polyhedron, i.e. for + permutations ``a``, ``b``, etc... in the group, ``a, a**2, ..., a**o_a``, + ``b, b**2, ..., b**o_b``, etc... (where ``o_i`` is the order of + permutation ``i``) generate all permutations of the polyhedron instead of + mixed products like ``a*b``, ``a*b**2``, etc.... + + Note that for a polyhedron with n vertices, the valid permutations of the + vertices exclude those that do not maintain its faces. e.g. the + permutation BCDE of a square's four corners, ABCD, is a valid + permutation while CBDE is not (because this would twist the square). + + Examples + ======== + + The is_group checks for: closure, the presence of the Identity permutation, + and the presence of the inverse for each of the elements in the group. This + confirms that none of the polyhedra are true groups: + + >>> from sympy.combinatorics.polyhedron import ( + ... tetrahedron, cube, octahedron, dodecahedron, icosahedron) + ... + >>> polyhedra = (tetrahedron, cube, octahedron, dodecahedron, icosahedron) + >>> [h.pgroup.is_group for h in polyhedra] + ... + [True, True, True, True, True] + + Although tests in polyhedron's test suite check that powers of the + permutations in the groups generate all permutations of the vertices + of the polyhedron, here we also demonstrate the powers of the given + permutations create a complete group for the tetrahedron: + + >>> from sympy.combinatorics import Permutation, PermutationGroup + >>> for h in polyhedra[:1]: + ... G = h.pgroup + ... perms = set() + ... for g in G: + ... for e in range(g.order()): + ... p = tuple((g**e).array_form) + ... perms.add(p) + ... + ... perms = [Permutation(p) for p in perms] + ... assert PermutationGroup(perms).is_group + + In addition to doing the above, the tests in the suite confirm that the + faces are all present after the application of each permutation. + + References + ========== + + .. [1] https://dogschool.tripod.com/trianglegroup.html + + """ + def _pgroup_of_double(polyh, ordered_faces, pgroup): + n = len(ordered_faces[0]) + # the vertices of the double which sits inside a give polyhedron + # can be found by tracking the faces of the outer polyhedron. + # A map between face and the vertex of the double is made so that + # after rotation the position of the vertices can be located + fmap = dict(zip(ordered_faces, + range(len(ordered_faces)))) + flat_faces = flatten(ordered_faces) + new_pgroup = [] + for p in pgroup: + h = polyh.copy() + h.rotate(p) + c = h.corners + # reorder corners in the order they should appear when + # enumerating the faces + reorder = unflatten([c[j] for j in flat_faces], n) + # make them canonical + reorder = [tuple(map(as_int, + minlex(f, directed=False))) + for f in reorder] + # map face to vertex: the resulting list of vertices are the + # permutation that we seek for the double + new_pgroup.append(Perm([fmap[f] for f in reorder])) + return new_pgroup + + tetrahedron_faces = [ + (0, 1, 2), (0, 2, 3), (0, 3, 1), # upper 3 + (1, 2, 3), # bottom + ] + + # cw from top + # + _t_pgroup = [ + Perm([[1, 2, 3], [0]]), # cw from top + Perm([[0, 1, 2], [3]]), # cw from front face + Perm([[0, 3, 2], [1]]), # cw from back right face + Perm([[0, 3, 1], [2]]), # cw from back left face + Perm([[0, 1], [2, 3]]), # through front left edge + Perm([[0, 2], [1, 3]]), # through front right edge + Perm([[0, 3], [1, 2]]), # through back edge + ] + + tetrahedron = Polyhedron( + range(4), + tetrahedron_faces, + _t_pgroup) + + cube_faces = [ + (0, 1, 2, 3), # upper + (0, 1, 5, 4), (1, 2, 6, 5), (2, 3, 7, 6), (0, 3, 7, 4), # middle 4 + (4, 5, 6, 7), # lower + ] + + # U, D, F, B, L, R = up, down, front, back, left, right + _c_pgroup = [Perm(p) for p in + [ + [1, 2, 3, 0, 5, 6, 7, 4], # cw from top, U + [4, 0, 3, 7, 5, 1, 2, 6], # cw from F face + [4, 5, 1, 0, 7, 6, 2, 3], # cw from R face + + [1, 0, 4, 5, 2, 3, 7, 6], # cw through UF edge + [6, 2, 1, 5, 7, 3, 0, 4], # cw through UR edge + [6, 7, 3, 2, 5, 4, 0, 1], # cw through UB edge + [3, 7, 4, 0, 2, 6, 5, 1], # cw through UL edge + [4, 7, 6, 5, 0, 3, 2, 1], # cw through FL edge + [6, 5, 4, 7, 2, 1, 0, 3], # cw through FR edge + + [0, 3, 7, 4, 1, 2, 6, 5], # cw through UFL vertex + [5, 1, 0, 4, 6, 2, 3, 7], # cw through UFR vertex + [5, 6, 2, 1, 4, 7, 3, 0], # cw through UBR vertex + [7, 4, 0, 3, 6, 5, 1, 2], # cw through UBL + ]] + + cube = Polyhedron( + range(8), + cube_faces, + _c_pgroup) + + octahedron_faces = [ + (0, 1, 2), (0, 2, 3), (0, 3, 4), (0, 1, 4), # top 4 + (1, 2, 5), (2, 3, 5), (3, 4, 5), (1, 4, 5), # bottom 4 + ] + + octahedron = Polyhedron( + range(6), + octahedron_faces, + _pgroup_of_double(cube, cube_faces, _c_pgroup)) + + dodecahedron_faces = [ + (0, 1, 2, 3, 4), # top + (0, 1, 6, 10, 5), (1, 2, 7, 11, 6), (2, 3, 8, 12, 7), # upper 5 + (3, 4, 9, 13, 8), (0, 4, 9, 14, 5), + (5, 10, 16, 15, 14), (6, 10, 16, 17, 11), (7, 11, 17, 18, + 12), # lower 5 + (8, 12, 18, 19, 13), (9, 13, 19, 15, 14), + (15, 16, 17, 18, 19) # bottom + ] + + def _string_to_perm(s): + rv = [Perm(range(20))] + p = None + for si in s: + if si not in '01': + count = int(si) - 1 + else: + count = 1 + if si == '0': + p = _f0 + elif si == '1': + p = _f1 + rv.extend([p]*count) + return Perm.rmul(*rv) + + # top face cw + _f0 = Perm([ + 1, 2, 3, 4, 0, 6, 7, 8, 9, 5, 11, + 12, 13, 14, 10, 16, 17, 18, 19, 15]) + # front face cw + _f1 = Perm([ + 5, 0, 4, 9, 14, 10, 1, 3, 13, 15, + 6, 2, 8, 19, 16, 17, 11, 7, 12, 18]) + # the strings below, like 0104 are shorthand for F0*F1*F0**4 and are + # the remaining 4 face rotations, 15 edge permutations, and the + # 10 vertex rotations. + _dodeca_pgroup = [_f0, _f1] + [_string_to_perm(s) for s in ''' + 0104 140 014 0410 + 010 1403 03104 04103 102 + 120 1304 01303 021302 03130 + 0412041 041204103 04120410 041204104 041204102 + 10 01 1402 0140 04102 0412 1204 1302 0130 03120'''.strip().split()] + + dodecahedron = Polyhedron( + range(20), + dodecahedron_faces, + _dodeca_pgroup) + + icosahedron_faces = [ + (0, 1, 2), (0, 2, 3), (0, 3, 4), (0, 4, 5), (0, 1, 5), + (1, 6, 7), (1, 2, 7), (2, 7, 8), (2, 3, 8), (3, 8, 9), + (3, 4, 9), (4, 9, 10), (4, 5, 10), (5, 6, 10), (1, 5, 6), + (6, 7, 11), (7, 8, 11), (8, 9, 11), (9, 10, 11), (6, 10, 11)] + + icosahedron = Polyhedron( + range(12), + icosahedron_faces, + _pgroup_of_double( + dodecahedron, dodecahedron_faces, _dodeca_pgroup)) + + return (tetrahedron, cube, octahedron, dodecahedron, icosahedron, + tetrahedron_faces, cube_faces, octahedron_faces, + dodecahedron_faces, icosahedron_faces) + +# ----------------------------------------------------------------------- +# Standard Polyhedron groups +# +# These are generated using _pgroup_calcs() above. However to save +# import time we encode them explicitly here. +# ----------------------------------------------------------------------- + +tetrahedron = Polyhedron( + Tuple(0, 1, 2, 3), + Tuple( + Tuple(0, 1, 2), + Tuple(0, 2, 3), + Tuple(0, 1, 3), + Tuple(1, 2, 3)), + Tuple( + Perm(1, 2, 3), + Perm(3)(0, 1, 2), + Perm(0, 3, 2), + Perm(0, 3, 1), + Perm(0, 1)(2, 3), + Perm(0, 2)(1, 3), + Perm(0, 3)(1, 2) + )) + +cube = Polyhedron( + Tuple(0, 1, 2, 3, 4, 5, 6, 7), + Tuple( + Tuple(0, 1, 2, 3), + Tuple(0, 1, 5, 4), + Tuple(1, 2, 6, 5), + Tuple(2, 3, 7, 6), + Tuple(0, 3, 7, 4), + Tuple(4, 5, 6, 7)), + Tuple( + Perm(0, 1, 2, 3)(4, 5, 6, 7), + Perm(0, 4, 5, 1)(2, 3, 7, 6), + Perm(0, 4, 7, 3)(1, 5, 6, 2), + Perm(0, 1)(2, 4)(3, 5)(6, 7), + Perm(0, 6)(1, 2)(3, 5)(4, 7), + Perm(0, 6)(1, 7)(2, 3)(4, 5), + Perm(0, 3)(1, 7)(2, 4)(5, 6), + Perm(0, 4)(1, 7)(2, 6)(3, 5), + Perm(0, 6)(1, 5)(2, 4)(3, 7), + Perm(1, 3, 4)(2, 7, 5), + Perm(7)(0, 5, 2)(3, 4, 6), + Perm(0, 5, 7)(1, 6, 3), + Perm(0, 7, 2)(1, 4, 6))) + +octahedron = Polyhedron( + Tuple(0, 1, 2, 3, 4, 5), + Tuple( + Tuple(0, 1, 2), + Tuple(0, 2, 3), + Tuple(0, 3, 4), + Tuple(0, 1, 4), + Tuple(1, 2, 5), + Tuple(2, 3, 5), + Tuple(3, 4, 5), + Tuple(1, 4, 5)), + Tuple( + Perm(5)(1, 2, 3, 4), + Perm(0, 4, 5, 2), + Perm(0, 1, 5, 3), + Perm(0, 1)(2, 4)(3, 5), + Perm(0, 2)(1, 3)(4, 5), + Perm(0, 3)(1, 5)(2, 4), + Perm(0, 4)(1, 3)(2, 5), + Perm(0, 5)(1, 4)(2, 3), + Perm(0, 5)(1, 2)(3, 4), + Perm(0, 4, 1)(2, 3, 5), + Perm(0, 1, 2)(3, 4, 5), + Perm(0, 2, 3)(1, 5, 4), + Perm(0, 4, 3)(1, 5, 2))) + +dodecahedron = Polyhedron( + Tuple(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19), + Tuple( + Tuple(0, 1, 2, 3, 4), + Tuple(0, 1, 6, 10, 5), + Tuple(1, 2, 7, 11, 6), + Tuple(2, 3, 8, 12, 7), + Tuple(3, 4, 9, 13, 8), + Tuple(0, 4, 9, 14, 5), + Tuple(5, 10, 16, 15, 14), + Tuple(6, 10, 16, 17, 11), + Tuple(7, 11, 17, 18, 12), + Tuple(8, 12, 18, 19, 13), + Tuple(9, 13, 19, 15, 14), + Tuple(15, 16, 17, 18, 19)), + Tuple( + Perm(0, 1, 2, 3, 4)(5, 6, 7, 8, 9)(10, 11, 12, 13, 14)(15, 16, 17, 18, 19), + Perm(0, 5, 10, 6, 1)(2, 4, 14, 16, 11)(3, 9, 15, 17, 7)(8, 13, 19, 18, 12), + Perm(0, 10, 17, 12, 3)(1, 6, 11, 7, 2)(4, 5, 16, 18, 8)(9, 14, 15, 19, 13), + Perm(0, 6, 17, 19, 9)(1, 11, 18, 13, 4)(2, 7, 12, 8, 3)(5, 10, 16, 15, 14), + Perm(0, 2, 12, 19, 14)(1, 7, 18, 15, 5)(3, 8, 13, 9, 4)(6, 11, 17, 16, 10), + Perm(0, 4, 9, 14, 5)(1, 3, 13, 15, 10)(2, 8, 19, 16, 6)(7, 12, 18, 17, 11), + Perm(0, 1)(2, 5)(3, 10)(4, 6)(7, 14)(8, 16)(9, 11)(12, 15)(13, 17)(18, 19), + Perm(0, 7)(1, 2)(3, 6)(4, 11)(5, 12)(8, 10)(9, 17)(13, 16)(14, 18)(15, 19), + Perm(0, 12)(1, 8)(2, 3)(4, 7)(5, 18)(6, 13)(9, 11)(10, 19)(14, 17)(15, 16), + Perm(0, 8)(1, 13)(2, 9)(3, 4)(5, 12)(6, 19)(7, 14)(10, 18)(11, 15)(16, 17), + Perm(0, 4)(1, 9)(2, 14)(3, 5)(6, 13)(7, 15)(8, 10)(11, 19)(12, 16)(17, 18), + Perm(0, 5)(1, 14)(2, 15)(3, 16)(4, 10)(6, 9)(7, 19)(8, 17)(11, 13)(12, 18), + Perm(0, 11)(1, 6)(2, 10)(3, 16)(4, 17)(5, 7)(8, 15)(9, 18)(12, 14)(13, 19), + Perm(0, 18)(1, 12)(2, 7)(3, 11)(4, 17)(5, 19)(6, 8)(9, 16)(10, 13)(14, 15), + Perm(0, 18)(1, 19)(2, 13)(3, 8)(4, 12)(5, 17)(6, 15)(7, 9)(10, 16)(11, 14), + Perm(0, 13)(1, 19)(2, 15)(3, 14)(4, 9)(5, 8)(6, 18)(7, 16)(10, 12)(11, 17), + Perm(0, 16)(1, 15)(2, 19)(3, 18)(4, 17)(5, 10)(6, 14)(7, 13)(8, 12)(9, 11), + Perm(0, 18)(1, 17)(2, 16)(3, 15)(4, 19)(5, 12)(6, 11)(7, 10)(8, 14)(9, 13), + Perm(0, 15)(1, 19)(2, 18)(3, 17)(4, 16)(5, 14)(6, 13)(7, 12)(8, 11)(9, 10), + Perm(0, 17)(1, 16)(2, 15)(3, 19)(4, 18)(5, 11)(6, 10)(7, 14)(8, 13)(9, 12), + Perm(0, 19)(1, 18)(2, 17)(3, 16)(4, 15)(5, 13)(6, 12)(7, 11)(8, 10)(9, 14), + Perm(1, 4, 5)(2, 9, 10)(3, 14, 6)(7, 13, 16)(8, 15, 11)(12, 19, 17), + Perm(19)(0, 6, 2)(3, 5, 11)(4, 10, 7)(8, 14, 17)(9, 16, 12)(13, 15, 18), + Perm(0, 11, 8)(1, 7, 3)(4, 6, 12)(5, 17, 13)(9, 10, 18)(14, 16, 19), + Perm(0, 7, 13)(1, 12, 9)(2, 8, 4)(5, 11, 19)(6, 18, 14)(10, 17, 15), + Perm(0, 3, 9)(1, 8, 14)(2, 13, 5)(6, 12, 15)(7, 19, 10)(11, 18, 16), + Perm(0, 14, 10)(1, 9, 16)(2, 13, 17)(3, 19, 11)(4, 15, 6)(7, 8, 18), + Perm(0, 16, 7)(1, 10, 11)(2, 5, 17)(3, 14, 18)(4, 15, 12)(8, 9, 19), + Perm(0, 16, 13)(1, 17, 8)(2, 11, 12)(3, 6, 18)(4, 10, 19)(5, 15, 9), + Perm(0, 11, 15)(1, 17, 14)(2, 18, 9)(3, 12, 13)(4, 7, 19)(5, 6, 16), + Perm(0, 8, 15)(1, 12, 16)(2, 18, 10)(3, 19, 5)(4, 13, 14)(6, 7, 17))) + +icosahedron = Polyhedron( + Tuple(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11), + Tuple( + Tuple(0, 1, 2), + Tuple(0, 2, 3), + Tuple(0, 3, 4), + Tuple(0, 4, 5), + Tuple(0, 1, 5), + Tuple(1, 6, 7), + Tuple(1, 2, 7), + Tuple(2, 7, 8), + Tuple(2, 3, 8), + Tuple(3, 8, 9), + Tuple(3, 4, 9), + Tuple(4, 9, 10), + Tuple(4, 5, 10), + Tuple(5, 6, 10), + Tuple(1, 5, 6), + Tuple(6, 7, 11), + Tuple(7, 8, 11), + Tuple(8, 9, 11), + Tuple(9, 10, 11), + Tuple(6, 10, 11)), + Tuple( + Perm(11)(1, 2, 3, 4, 5)(6, 7, 8, 9, 10), + Perm(0, 5, 6, 7, 2)(3, 4, 10, 11, 8), + Perm(0, 1, 7, 8, 3)(4, 5, 6, 11, 9), + Perm(0, 2, 8, 9, 4)(1, 7, 11, 10, 5), + Perm(0, 3, 9, 10, 5)(1, 2, 8, 11, 6), + Perm(0, 4, 10, 6, 1)(2, 3, 9, 11, 7), + Perm(0, 1)(2, 5)(3, 6)(4, 7)(8, 10)(9, 11), + Perm(0, 2)(1, 3)(4, 7)(5, 8)(6, 9)(10, 11), + Perm(0, 3)(1, 9)(2, 4)(5, 8)(6, 11)(7, 10), + Perm(0, 4)(1, 9)(2, 10)(3, 5)(6, 8)(7, 11), + Perm(0, 5)(1, 4)(2, 10)(3, 6)(7, 9)(8, 11), + Perm(0, 6)(1, 5)(2, 10)(3, 11)(4, 7)(8, 9), + Perm(0, 7)(1, 2)(3, 6)(4, 11)(5, 8)(9, 10), + Perm(0, 8)(1, 9)(2, 3)(4, 7)(5, 11)(6, 10), + Perm(0, 9)(1, 11)(2, 10)(3, 4)(5, 8)(6, 7), + Perm(0, 10)(1, 9)(2, 11)(3, 6)(4, 5)(7, 8), + Perm(0, 11)(1, 6)(2, 10)(3, 9)(4, 8)(5, 7), + Perm(0, 11)(1, 8)(2, 7)(3, 6)(4, 10)(5, 9), + Perm(0, 11)(1, 10)(2, 9)(3, 8)(4, 7)(5, 6), + Perm(0, 11)(1, 7)(2, 6)(3, 10)(4, 9)(5, 8), + Perm(0, 11)(1, 9)(2, 8)(3, 7)(4, 6)(5, 10), + Perm(0, 5, 1)(2, 4, 6)(3, 10, 7)(8, 9, 11), + Perm(0, 1, 2)(3, 5, 7)(4, 6, 8)(9, 10, 11), + Perm(0, 2, 3)(1, 8, 4)(5, 7, 9)(6, 11, 10), + Perm(0, 3, 4)(1, 8, 10)(2, 9, 5)(6, 7, 11), + Perm(0, 4, 5)(1, 3, 10)(2, 9, 6)(7, 8, 11), + Perm(0, 10, 7)(1, 5, 6)(2, 4, 11)(3, 9, 8), + Perm(0, 6, 8)(1, 7, 2)(3, 5, 11)(4, 10, 9), + Perm(0, 7, 9)(1, 11, 4)(2, 8, 3)(5, 6, 10), + Perm(0, 8, 10)(1, 7, 6)(2, 11, 5)(3, 9, 4), + Perm(0, 9, 6)(1, 3, 11)(2, 8, 7)(4, 10, 5))) + +tetrahedron_faces = [tuple(arg) for arg in tetrahedron.faces] + +cube_faces = [tuple(arg) for arg in cube.faces] + +octahedron_faces = [tuple(arg) for arg in octahedron.faces] + +dodecahedron_faces = [tuple(arg) for arg in dodecahedron.faces] + +icosahedron_faces = [tuple(arg) for arg in icosahedron.faces] diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/prufer.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/prufer.py new file mode 100644 index 0000000000000000000000000000000000000000..e389df87cddc0152b2376e18b9f3df2e94d3d2fb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/prufer.py @@ -0,0 +1,435 @@ +from sympy.core import Basic +from sympy.core.containers import Tuple +from sympy.tensor.array import Array +from sympy.core.sympify import _sympify +from sympy.utilities.iterables import flatten, iterable +from sympy.utilities.misc import as_int + +from collections import defaultdict + + +class Prufer(Basic): + """ + The Prufer correspondence is an algorithm that describes the + bijection between labeled trees and the Prufer code. A Prufer + code of a labeled tree is unique up to isomorphism and has + a length of n - 2. + + Prufer sequences were first used by Heinz Prufer to give a + proof of Cayley's formula. + + References + ========== + + .. [1] https://mathworld.wolfram.com/LabeledTree.html + + """ + _prufer_repr = None + _tree_repr = None + _nodes = None + _rank = None + + @property + def prufer_repr(self): + """Returns Prufer sequence for the Prufer object. + + This sequence is found by removing the highest numbered vertex, + recording the node it was attached to, and continuing until only + two vertices remain. The Prufer sequence is the list of recorded nodes. + + Examples + ======== + + >>> from sympy.combinatorics.prufer import Prufer + >>> Prufer([[0, 3], [1, 3], [2, 3], [3, 4], [4, 5]]).prufer_repr + [3, 3, 3, 4] + >>> Prufer([1, 0, 0]).prufer_repr + [1, 0, 0] + + See Also + ======== + + to_prufer + + """ + if self._prufer_repr is None: + self._prufer_repr = self.to_prufer(self._tree_repr[:], self.nodes) + return self._prufer_repr + + @property + def tree_repr(self): + """Returns the tree representation of the Prufer object. + + Examples + ======== + + >>> from sympy.combinatorics.prufer import Prufer + >>> Prufer([[0, 3], [1, 3], [2, 3], [3, 4], [4, 5]]).tree_repr + [[0, 3], [1, 3], [2, 3], [3, 4], [4, 5]] + >>> Prufer([1, 0, 0]).tree_repr + [[1, 2], [0, 1], [0, 3], [0, 4]] + + See Also + ======== + + to_tree + + """ + if self._tree_repr is None: + self._tree_repr = self.to_tree(self._prufer_repr[:]) + return self._tree_repr + + @property + def nodes(self): + """Returns the number of nodes in the tree. + + Examples + ======== + + >>> from sympy.combinatorics.prufer import Prufer + >>> Prufer([[0, 3], [1, 3], [2, 3], [3, 4], [4, 5]]).nodes + 6 + >>> Prufer([1, 0, 0]).nodes + 5 + + """ + return self._nodes + + @property + def rank(self): + """Returns the rank of the Prufer sequence. + + Examples + ======== + + >>> from sympy.combinatorics.prufer import Prufer + >>> p = Prufer([[0, 3], [1, 3], [2, 3], [3, 4], [4, 5]]) + >>> p.rank + 778 + >>> p.next(1).rank + 779 + >>> p.prev().rank + 777 + + See Also + ======== + + prufer_rank, next, prev, size + + """ + if self._rank is None: + self._rank = self.prufer_rank() + return self._rank + + @property + def size(self): + """Return the number of possible trees of this Prufer object. + + Examples + ======== + + >>> from sympy.combinatorics.prufer import Prufer + >>> Prufer([0]*4).size == Prufer([6]*4).size == 1296 + True + + See Also + ======== + + prufer_rank, rank, next, prev + + """ + return self.prev(self.rank).prev().rank + 1 + + @staticmethod + def to_prufer(tree, n): + """Return the Prufer sequence for a tree given as a list of edges where + ``n`` is the number of nodes in the tree. + + Examples + ======== + + >>> from sympy.combinatorics.prufer import Prufer + >>> a = Prufer([[0, 1], [0, 2], [0, 3]]) + >>> a.prufer_repr + [0, 0] + >>> Prufer.to_prufer([[0, 1], [0, 2], [0, 3]], 4) + [0, 0] + + See Also + ======== + prufer_repr: returns Prufer sequence of a Prufer object. + + """ + d = defaultdict(int) + L = [] + for edge in tree: + # Increment the value of the corresponding + # node in the degree list as we encounter an + # edge involving it. + d[edge[0]] += 1 + d[edge[1]] += 1 + for i in range(n - 2): + # find the smallest leaf + for x in range(n): + if d[x] == 1: + break + # find the node it was connected to + y = None + for edge in tree: + if x == edge[0]: + y = edge[1] + elif x == edge[1]: + y = edge[0] + if y is not None: + break + # record and update + L.append(y) + for j in (x, y): + d[j] -= 1 + if not d[j]: + d.pop(j) + tree.remove(edge) + return L + + @staticmethod + def to_tree(prufer): + """Return the tree (as a list of edges) of the given Prufer sequence. + + Examples + ======== + + >>> from sympy.combinatorics.prufer import Prufer + >>> a = Prufer([0, 2], 4) + >>> a.tree_repr + [[0, 1], [0, 2], [2, 3]] + >>> Prufer.to_tree([0, 2]) + [[0, 1], [0, 2], [2, 3]] + + References + ========== + + .. [1] https://hamberg.no/erlend/posts/2010-11-06-prufer-sequence-compact-tree-representation.html + + See Also + ======== + tree_repr: returns tree representation of a Prufer object. + + """ + tree = [] + last = [] + n = len(prufer) + 2 + d = defaultdict(lambda: 1) + for p in prufer: + d[p] += 1 + for i in prufer: + for j in range(n): + # find the smallest leaf (degree = 1) + if d[j] == 1: + break + # (i, j) is the new edge that we append to the tree + # and remove from the degree dictionary + d[i] -= 1 + d[j] -= 1 + tree.append(sorted([i, j])) + last = [i for i in range(n) if d[i] == 1] or [0, 1] + tree.append(last) + + return tree + + @staticmethod + def edges(*runs): + """Return a list of edges and the number of nodes from the given runs + that connect nodes in an integer-labelled tree. + + All node numbers will be shifted so that the minimum node is 0. It is + not a problem if edges are repeated in the runs; only unique edges are + returned. There is no assumption made about what the range of the node + labels should be, but all nodes from the smallest through the largest + must be present. + + Examples + ======== + + >>> from sympy.combinatorics.prufer import Prufer + >>> Prufer.edges([1, 2, 3], [2, 4, 5]) # a T + ([[0, 1], [1, 2], [1, 3], [3, 4]], 5) + + Duplicate edges are removed: + + >>> Prufer.edges([0, 1, 2, 3], [1, 4, 5], [1, 4, 6]) # a K + ([[0, 1], [1, 2], [1, 4], [2, 3], [4, 5], [4, 6]], 7) + + """ + e = set() + nmin = runs[0][0] + for r in runs: + for i in range(len(r) - 1): + a, b = r[i: i + 2] + if b < a: + a, b = b, a + e.add((a, b)) + rv = [] + got = set() + nmin = nmax = None + for ei in e: + got.update(ei) + nmin = min(ei[0], nmin) if nmin is not None else ei[0] + nmax = max(ei[1], nmax) if nmax is not None else ei[1] + rv.append(list(ei)) + missing = set(range(nmin, nmax + 1)) - got + if missing: + missing = [i + nmin for i in missing] + if len(missing) == 1: + msg = 'Node %s is missing.' % missing.pop() + else: + msg = 'Nodes %s are missing.' % sorted(missing) + raise ValueError(msg) + if nmin != 0: + for i, ei in enumerate(rv): + rv[i] = [n - nmin for n in ei] + nmax -= nmin + return sorted(rv), nmax + 1 + + def prufer_rank(self): + """Computes the rank of a Prufer sequence. + + Examples + ======== + + >>> from sympy.combinatorics.prufer import Prufer + >>> a = Prufer([[0, 1], [0, 2], [0, 3]]) + >>> a.prufer_rank() + 0 + + See Also + ======== + + rank, next, prev, size + + """ + r = 0 + p = 1 + for i in range(self.nodes - 3, -1, -1): + r += p*self.prufer_repr[i] + p *= self.nodes + return r + + @classmethod + def unrank(self, rank, n): + """Finds the unranked Prufer sequence. + + Examples + ======== + + >>> from sympy.combinatorics.prufer import Prufer + >>> Prufer.unrank(0, 4) + Prufer([0, 0]) + + """ + n, rank = as_int(n), as_int(rank) + L = defaultdict(int) + for i in range(n - 3, -1, -1): + L[i] = rank % n + rank = (rank - L[i])//n + return Prufer([L[i] for i in range(len(L))]) + + def __new__(cls, *args, **kw_args): + """The constructor for the Prufer object. + + Examples + ======== + + >>> from sympy.combinatorics.prufer import Prufer + + A Prufer object can be constructed from a list of edges: + + >>> a = Prufer([[0, 1], [0, 2], [0, 3]]) + >>> a.prufer_repr + [0, 0] + + If the number of nodes is given, no checking of the nodes will + be performed; it will be assumed that nodes 0 through n - 1 are + present: + + >>> Prufer([[0, 1], [0, 2], [0, 3]], 4) + Prufer([[0, 1], [0, 2], [0, 3]], 4) + + A Prufer object can be constructed from a Prufer sequence: + + >>> b = Prufer([1, 3]) + >>> b.tree_repr + [[0, 1], [1, 3], [2, 3]] + + """ + arg0 = Array(args[0]) if args[0] else Tuple() + args = (arg0,) + tuple(_sympify(arg) for arg in args[1:]) + ret_obj = Basic.__new__(cls, *args, **kw_args) + args = [list(args[0])] + if args[0] and iterable(args[0][0]): + if not args[0][0]: + raise ValueError( + 'Prufer expects at least one edge in the tree.') + if len(args) > 1: + nnodes = args[1] + else: + nodes = set(flatten(args[0])) + nnodes = max(nodes) + 1 + if nnodes != len(nodes): + missing = set(range(nnodes)) - nodes + if len(missing) == 1: + msg = 'Node %s is missing.' % missing.pop() + else: + msg = 'Nodes %s are missing.' % sorted(missing) + raise ValueError(msg) + ret_obj._tree_repr = [list(i) for i in args[0]] + ret_obj._nodes = nnodes + else: + ret_obj._prufer_repr = args[0] + ret_obj._nodes = len(ret_obj._prufer_repr) + 2 + return ret_obj + + def next(self, delta=1): + """Generates the Prufer sequence that is delta beyond the current one. + + Examples + ======== + + >>> from sympy.combinatorics.prufer import Prufer + >>> a = Prufer([[0, 1], [0, 2], [0, 3]]) + >>> b = a.next(1) # == a.next() + >>> b.tree_repr + [[0, 2], [0, 1], [1, 3]] + >>> b.rank + 1 + + See Also + ======== + + prufer_rank, rank, prev, size + + """ + return Prufer.unrank(self.rank + delta, self.nodes) + + def prev(self, delta=1): + """Generates the Prufer sequence that is -delta before the current one. + + Examples + ======== + + >>> from sympy.combinatorics.prufer import Prufer + >>> a = Prufer([[0, 1], [1, 2], [2, 3], [1, 4]]) + >>> a.rank + 36 + >>> b = a.prev() + >>> b + Prufer([1, 2, 0]) + >>> b.rank + 35 + + See Also + ======== + + prufer_rank, rank, next, size + + """ + return Prufer.unrank(self.rank -delta, self.nodes) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/rewritingsystem.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/rewritingsystem.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e8dfe4c831db6490c987c85a57d0d1a939de61 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/rewritingsystem.py @@ -0,0 +1,453 @@ +from collections import deque +from sympy.combinatorics.rewritingsystem_fsm import StateMachine + +class RewritingSystem: + ''' + A class implementing rewriting systems for `FpGroup`s. + + References + ========== + .. [1] Epstein, D., Holt, D. and Rees, S. (1991). + The use of Knuth-Bendix methods to solve the word problem in automatic groups. + Journal of Symbolic Computation, 12(4-5), pp.397-414. + + .. [2] GAP's Manual on its KBMAG package + https://www.gap-system.org/Manuals/pkg/kbmag-1.5.3/doc/manual.pdf + + ''' + def __init__(self, group): + self.group = group + self.alphabet = group.generators + self._is_confluent = None + + # these values are taken from [2] + self.maxeqns = 32767 # max rules + self.tidyint = 100 # rules before tidying + + # _max_exceeded is True if maxeqns is exceeded + # at any point + self._max_exceeded = False + + # Reduction automaton + self.reduction_automaton = None + self._new_rules = {} + + # dictionary of reductions + self.rules = {} + self.rules_cache = deque([], 50) + self._init_rules() + + + # All the transition symbols in the automaton + generators = list(self.alphabet) + generators += [gen**-1 for gen in generators] + # Create a finite state machine as an instance of the StateMachine object + self.reduction_automaton = StateMachine('Reduction automaton for '+ repr(self.group), generators) + self.construct_automaton() + + def set_max(self, n): + ''' + Set the maximum number of rules that can be defined + + ''' + if n > self.maxeqns: + self._max_exceeded = False + self.maxeqns = n + return + + @property + def is_confluent(self): + ''' + Return `True` if the system is confluent + + ''' + if self._is_confluent is None: + self._is_confluent = self._check_confluence() + return self._is_confluent + + def _init_rules(self): + identity = self.group.free_group.identity + for r in self.group.relators: + self.add_rule(r, identity) + self._remove_redundancies() + return + + def _add_rule(self, r1, r2): + ''' + Add the rule r1 -> r2 with no checking or further + deductions + + ''' + if len(self.rules) + 1 > self.maxeqns: + self._is_confluent = self._check_confluence() + self._max_exceeded = True + raise RuntimeError("Too many rules were defined.") + self.rules[r1] = r2 + # Add the newly added rule to the `new_rules` dictionary. + if self.reduction_automaton: + self._new_rules[r1] = r2 + + def add_rule(self, w1, w2, check=False): + new_keys = set() + + if w1 == w2: + return new_keys + + if w1 < w2: + w1, w2 = w2, w1 + + if (w1, w2) in self.rules_cache: + return new_keys + self.rules_cache.append((w1, w2)) + + s1, s2 = w1, w2 + + # The following is the equivalent of checking + # s1 for overlaps with the implicit reductions + # {g*g**-1 -> } and {g**-1*g -> } + # for any generator g without installing the + # redundant rules that would result from processing + # the overlaps. See [1], Section 3 for details. + + if len(s1) - len(s2) < 3: + if s1 not in self.rules: + new_keys.add(s1) + if not check: + self._add_rule(s1, s2) + if s2**-1 > s1**-1 and s2**-1 not in self.rules: + new_keys.add(s2**-1) + if not check: + self._add_rule(s2**-1, s1**-1) + + # overlaps on the right + while len(s1) - len(s2) > -1: + g = s1[len(s1)-1] + s1 = s1.subword(0, len(s1)-1) + s2 = s2*g**-1 + if len(s1) - len(s2) < 0: + if s2 not in self.rules: + if not check: + self._add_rule(s2, s1) + new_keys.add(s2) + elif len(s1) - len(s2) < 3: + new = self.add_rule(s1, s2, check) + new_keys.update(new) + + # overlaps on the left + while len(w1) - len(w2) > -1: + g = w1[0] + w1 = w1.subword(1, len(w1)) + w2 = g**-1*w2 + if len(w1) - len(w2) < 0: + if w2 not in self.rules: + if not check: + self._add_rule(w2, w1) + new_keys.add(w2) + elif len(w1) - len(w2) < 3: + new = self.add_rule(w1, w2, check) + new_keys.update(new) + + return new_keys + + def _remove_redundancies(self, changes=False): + ''' + Reduce left- and right-hand sides of reduction rules + and remove redundant equations (i.e. those for which + lhs == rhs). If `changes` is `True`, return a set + containing the removed keys and a set containing the + added keys + + ''' + removed = set() + added = set() + rules = self.rules.copy() + for r in rules: + v = self.reduce(r, exclude=r) + w = self.reduce(rules[r]) + if v != r: + del self.rules[r] + removed.add(r) + if v > w: + added.add(v) + self.rules[v] = w + elif v < w: + added.add(w) + self.rules[w] = v + else: + self.rules[v] = w + if changes: + return removed, added + return + + def make_confluent(self, check=False): + ''' + Try to make the system confluent using the Knuth-Bendix + completion algorithm + + ''' + if self._max_exceeded: + return self._is_confluent + lhs = list(self.rules.keys()) + + def _overlaps(r1, r2): + len1 = len(r1) + len2 = len(r2) + result = [] + for j in range(1, len1 + len2): + if (r1.subword(len1 - j, len1 + len2 - j, strict=False) + == r2.subword(j - len1, j, strict=False)): + a = r1.subword(0, len1-j, strict=False) + a = a*r2.subword(0, j-len1, strict=False) + b = r2.subword(j-len1, j, strict=False) + c = r2.subword(j, len2, strict=False) + c = c*r1.subword(len1 + len2 - j, len1, strict=False) + result.append(a*b*c) + return result + + def _process_overlap(w, r1, r2, check): + s = w.eliminate_word(r1, self.rules[r1]) + s = self.reduce(s) + t = w.eliminate_word(r2, self.rules[r2]) + t = self.reduce(t) + if s != t: + if check: + # system not confluent + return [0] + try: + new_keys = self.add_rule(t, s, check) + return new_keys + except RuntimeError: + return False + return + + added = 0 + i = 0 + while i < len(lhs): + r1 = lhs[i] + i += 1 + # j could be i+1 to not + # check each pair twice but lhs + # is extended in the loop and the new + # elements have to be checked with the + # preceding ones. there is probably a better way + # to handle this + j = 0 + while j < len(lhs): + r2 = lhs[j] + j += 1 + if r1 == r2: + continue + overlaps = _overlaps(r1, r2) + overlaps.extend(_overlaps(r1**-1, r2)) + if not overlaps: + continue + for w in overlaps: + new_keys = _process_overlap(w, r1, r2, check) + if new_keys: + if check: + return False + lhs.extend(new_keys) + added += len(new_keys) + elif new_keys == False: + # too many rules were added so the process + # couldn't complete + return self._is_confluent + + if added > self.tidyint and not check: + # tidy up + r, a = self._remove_redundancies(changes=True) + added = 0 + if r: + # reset i since some elements were removed + i = min(lhs.index(s) for s in r) + lhs = [l for l in lhs if l not in r] + lhs.extend(a) + if r1 in r: + # r1 was removed as redundant + break + + self._is_confluent = True + if not check: + self._remove_redundancies() + return True + + def _check_confluence(self): + return self.make_confluent(check=True) + + def reduce(self, word, exclude=None): + ''' + Apply reduction rules to `word` excluding the reduction rule + for the lhs equal to `exclude` + + ''' + rules = {r: self.rules[r] for r in self.rules if r != exclude} + # the following is essentially `eliminate_words()` code from the + # `FreeGroupElement` class, the only difference being the first + # "if" statement + again = True + new = word + while again: + again = False + for r in rules: + prev = new + if rules[r]**-1 > r**-1: + new = new.eliminate_word(r, rules[r], _all=True, inverse=False) + else: + new = new.eliminate_word(r, rules[r], _all=True) + if new != prev: + again = True + return new + + def _compute_inverse_rules(self, rules): + ''' + Compute the inverse rules for a given set of rules. + The inverse rules are used in the automaton for word reduction. + + Arguments: + rules (dictionary): Rules for which the inverse rules are to computed. + + Returns: + Dictionary of inverse_rules. + + ''' + inverse_rules = {} + for r in rules: + rule_key_inverse = r**-1 + rule_value_inverse = (rules[r])**-1 + if (rule_value_inverse < rule_key_inverse): + inverse_rules[rule_key_inverse] = rule_value_inverse + else: + inverse_rules[rule_value_inverse] = rule_key_inverse + return inverse_rules + + def construct_automaton(self): + ''' + Construct the automaton based on the set of reduction rules of the system. + + Automata Design: + The accept states of the automaton are the proper prefixes of the left hand side of the rules. + The complete left hand side of the rules are the dead states of the automaton. + + ''' + self._add_to_automaton(self.rules) + + def _add_to_automaton(self, rules): + ''' + Add new states and transitions to the automaton. + + Summary: + States corresponding to the new rules added to the system are computed and added to the automaton. + Transitions in the previously added states are also modified if necessary. + + Arguments: + rules (dictionary) -- Dictionary of the newly added rules. + + ''' + # Automaton variables + automaton_alphabet = [] + proper_prefixes = {} + + # compute the inverses of all the new rules added + all_rules = rules + inverse_rules = self._compute_inverse_rules(all_rules) + all_rules.update(inverse_rules) + + # Keep track of the accept_states. + accept_states = [] + + for rule in all_rules: + # The symbols present in the new rules are the symbols to be verified at each state. + # computes the automaton_alphabet, as the transitions solely depend upon the new states. + automaton_alphabet += rule.letter_form_elm + # Compute the proper prefixes for every rule. + proper_prefixes[rule] = [] + letter_word_array = list(rule.letter_form_elm) + len_letter_word_array = len(letter_word_array) + for i in range (1, len_letter_word_array): + letter_word_array[i] = letter_word_array[i-1]*letter_word_array[i] + # Add accept states. + elem = letter_word_array[i-1] + if elem not in self.reduction_automaton.states: + self.reduction_automaton.add_state(elem, state_type='a') + accept_states.append(elem) + proper_prefixes[rule] = letter_word_array + # Check for overlaps between dead and accept states. + if rule in accept_states: + self.reduction_automaton.states[rule].state_type = 'd' + self.reduction_automaton.states[rule].rh_rule = all_rules[rule] + accept_states.remove(rule) + # Add dead states + if rule not in self.reduction_automaton.states: + self.reduction_automaton.add_state(rule, state_type='d', rh_rule=all_rules[rule]) + + automaton_alphabet = set(automaton_alphabet) + + # Add new transitions for every state. + for state in self.reduction_automaton.states: + current_state_name = state + current_state_type = self.reduction_automaton.states[state].state_type + # Transitions will be modified only when suffixes of the current_state + # belongs to the proper_prefixes of the new rules. + # The rest are ignored if they cannot lead to a dead state after a finite number of transisitons. + if current_state_type == 's': + for letter in automaton_alphabet: + if letter in self.reduction_automaton.states: + self.reduction_automaton.states[state].add_transition(letter, letter) + else: + self.reduction_automaton.states[state].add_transition(letter, current_state_name) + elif current_state_type == 'a': + # Check if the transition to any new state in possible. + for letter in automaton_alphabet: + _next = current_state_name*letter + while len(_next) and _next not in self.reduction_automaton.states: + _next = _next.subword(1, len(_next)) + if not len(_next): + _next = 'start' + self.reduction_automaton.states[state].add_transition(letter, _next) + + # Add transitions for new states. All symbols used in the automaton are considered here. + # Ignore this if `reduction_automaton.automaton_alphabet` = `automaton_alphabet`. + if len(self.reduction_automaton.automaton_alphabet) != len(automaton_alphabet): + for state in accept_states: + current_state_name = state + for letter in self.reduction_automaton.automaton_alphabet: + _next = current_state_name*letter + while len(_next) and _next not in self.reduction_automaton.states: + _next = _next.subword(1, len(_next)) + if not len(_next): + _next = 'start' + self.reduction_automaton.states[state].add_transition(letter, _next) + + def reduce_using_automaton(self, word): + ''' + Reduce a word using an automaton. + + Summary: + All the symbols of the word are stored in an array and are given as the input to the automaton. + If the automaton reaches a dead state that subword is replaced and the automaton is run from the beginning. + The complete word has to be replaced when the word is read and the automaton reaches a dead state. + So, this process is repeated until the word is read completely and the automaton reaches the accept state. + + Arguments: + word (instance of FreeGroupElement) -- Word that needs to be reduced. + + ''' + # Modify the automaton if new rules are found. + if self._new_rules: + self._add_to_automaton(self._new_rules) + self._new_rules = {} + + flag = 1 + while flag: + flag = 0 + current_state = self.reduction_automaton.states['start'] + for i, s in enumerate(word.letter_form_elm): + next_state_name = current_state.transitions[s] + next_state = self.reduction_automaton.states[next_state_name] + if next_state.state_type == 'd': + subst = next_state.rh_rule + word = word.substituted_word(i - len(next_state_name) + 1, i+1, subst) + flag = 1 + break + current_state = next_state + return word diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/rewritingsystem_fsm.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/rewritingsystem_fsm.py new file mode 100644 index 0000000000000000000000000000000000000000..21916530040ac321180692d1a0811da4ae36a056 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/rewritingsystem_fsm.py @@ -0,0 +1,60 @@ +class State: + ''' + A representation of a state managed by a ``StateMachine``. + + Attributes: + name (instance of FreeGroupElement or string) -- State name which is also assigned to the Machine. + transisitons (OrderedDict) -- Represents all the transitions of the state object. + state_type (string) -- Denotes the type (accept/start/dead) of the state. + rh_rule (instance of FreeGroupElement) -- right hand rule for dead state. + state_machine (instance of StateMachine object) -- The finite state machine that the state belongs to. + ''' + + def __init__(self, name, state_machine, state_type=None, rh_rule=None): + self.name = name + self.transitions = {} + self.state_machine = state_machine + self.state_type = state_type[0] + self.rh_rule = rh_rule + + def add_transition(self, letter, state): + ''' + Add a transition from the current state to a new state. + + Keyword Arguments: + letter -- The alphabet element the current state reads to make the state transition. + state -- This will be an instance of the State object which represents a new state after in the transition after the alphabet is read. + + ''' + self.transitions[letter] = state + +class StateMachine: + ''' + Representation of a finite state machine the manages the states and the transitions of the automaton. + + Attributes: + states (dictionary) -- Collection of all registered `State` objects. + name (str) -- Name of the state machine. + ''' + + def __init__(self, name, automaton_alphabet): + self.name = name + self.automaton_alphabet = automaton_alphabet + self.states = {} # Contains all the states in the machine. + self.add_state('start', state_type='s') + + def add_state(self, state_name, state_type=None, rh_rule=None): + ''' + Instantiate a state object and stores it in the 'states' dictionary. + + Arguments: + state_name (instance of FreeGroupElement or string) -- name of the new states. + state_type (string) -- Denotes the type (accept/start/dead) of the state added. + rh_rule (instance of FreeGroupElement) -- right hand rule for dead state. + + ''' + new_state = State(state_name, self, state_type, rh_rule) + self.states[state_name] = new_state + + def __repr__(self): + return "%s" % (self.name) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/schur_number.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/schur_number.py new file mode 100644 index 0000000000000000000000000000000000000000..83aac98e543d4b54d4e6af17adca6e4f4de1b9ac --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/schur_number.py @@ -0,0 +1,160 @@ +""" +The Schur number S(k) is the largest integer n for which the interval [1,n] +can be partitioned into k sum-free sets.(https://mathworld.wolfram.com/SchurNumber.html) +""" +import math +from sympy.core import S +from sympy.core.basic import Basic +from sympy.core.function import Function +from sympy.core.numbers import Integer + + +class SchurNumber(Function): + r""" + This function creates a SchurNumber object + which is evaluated for `k \le 5` otherwise only + the lower bound information can be retrieved. + + Examples + ======== + + >>> from sympy.combinatorics.schur_number import SchurNumber + + Since S(3) = 13, hence the output is a number + >>> SchurNumber(3) + 13 + + We do not know the Schur number for values greater than 5, hence + only the object is returned + >>> SchurNumber(6) + SchurNumber(6) + + Now, the lower bound information can be retrieved using lower_bound() + method + >>> SchurNumber(6).lower_bound() + 536 + + """ + + @classmethod + def eval(cls, k): + if k.is_Number: + if k is S.Infinity: + return S.Infinity + if k.is_zero: + return S.Zero + if not k.is_integer or k.is_negative: + raise ValueError("k should be a positive integer") + first_known_schur_numbers = {1: 1, 2: 4, 3: 13, 4: 44, 5: 160} + if k <= 5: + return Integer(first_known_schur_numbers[k]) + + def lower_bound(self): + f_ = self.args[0] + # Improved lower bounds known for S(6) and S(7) + if f_ == 6: + return Integer(536) + if f_ == 7: + return Integer(1680) + # For other cases, use general expression + if f_.is_Integer: + return 3*self.func(f_ - 1).lower_bound() - 1 + return (3**f_ - 1)/2 + + +def _schur_subsets_number(n): + + if n is S.Infinity: + raise ValueError("Input must be finite") + if n <= 0: + raise ValueError("n must be a non-zero positive integer.") + elif n <= 3: + min_k = 1 + else: + min_k = math.ceil(math.log(2*n + 1, 3)) + + return Integer(min_k) + + +def schur_partition(n): + """ + + This function returns the partition in the minimum number of sum-free subsets + according to the lower bound given by the Schur Number. + + Parameters + ========== + + n: a number + n is the upper limit of the range [1, n] for which we need to find and + return the minimum number of free subsets according to the lower bound + of schur number + + Returns + ======= + + List of lists + List of the minimum number of sum-free subsets + + Notes + ===== + + It is possible for some n to make the partition into less + subsets since the only known Schur numbers are: + S(1) = 1, S(2) = 4, S(3) = 13, S(4) = 44. + e.g for n = 44 the lower bound from the function above is 5 subsets but it has been proven + that can be done with 4 subsets. + + Examples + ======== + + For n = 1, 2, 3 the answer is the set itself + + >>> from sympy.combinatorics.schur_number import schur_partition + >>> schur_partition(2) + [[1, 2]] + + For n > 3, the answer is the minimum number of sum-free subsets: + + >>> schur_partition(5) + [[3, 2], [5], [1, 4]] + + >>> schur_partition(8) + [[3, 2], [6, 5, 8], [1, 4, 7]] + """ + + if isinstance(n, Basic) and not n.is_Number: + raise ValueError("Input value must be a number") + + number_of_subsets = _schur_subsets_number(n) + if n == 1: + sum_free_subsets = [[1]] + elif n == 2: + sum_free_subsets = [[1, 2]] + elif n == 3: + sum_free_subsets = [[1, 2, 3]] + else: + sum_free_subsets = [[1, 4], [2, 3]] + + while len(sum_free_subsets) < number_of_subsets: + sum_free_subsets = _generate_next_list(sum_free_subsets, n) + missed_elements = [3*k + 1 for k in range(len(sum_free_subsets), (n-1)//3 + 1)] + sum_free_subsets[-1] += missed_elements + + return sum_free_subsets + + +def _generate_next_list(current_list, n): + new_list = [] + + for item in current_list: + temp_1 = [number*3 for number in item if number*3 <= n] + temp_2 = [number*3 - 1 for number in item if number*3 - 1 <= n] + new_item = temp_1 + temp_2 + new_list.append(new_item) + + last_list = [3*k + 1 for k in range(len(current_list)+1) if 3*k + 1 <= n] + new_list.append(last_list) + current_list = new_list + + return current_list diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/subsets.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/subsets.py new file mode 100644 index 0000000000000000000000000000000000000000..e540cb2395cb27e04c9d513831cb834a05ec2abd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/subsets.py @@ -0,0 +1,619 @@ +from itertools import combinations + +from sympy.combinatorics.graycode import GrayCode + + +class Subset(): + """ + Represents a basic subset object. + + Explanation + =========== + + We generate subsets using essentially two techniques, + binary enumeration and lexicographic enumeration. + The Subset class takes two arguments, the first one + describes the initial subset to consider and the second + describes the superset. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset(['c', 'd'], ['a', 'b', 'c', 'd']) + >>> a.next_binary().subset + ['b'] + >>> a.prev_binary().subset + ['c'] + """ + + _rank_binary = None + _rank_lex = None + _rank_graycode = None + _subset = None + _superset = None + + def __new__(cls, subset, superset): + """ + Default constructor. + + It takes the ``subset`` and its ``superset`` as its parameters. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset(['c', 'd'], ['a', 'b', 'c', 'd']) + >>> a.subset + ['c', 'd'] + >>> a.superset + ['a', 'b', 'c', 'd'] + >>> a.size + 2 + """ + if len(subset) > len(superset): + raise ValueError('Invalid arguments have been provided. The ' + 'superset must be larger than the subset.') + for elem in subset: + if elem not in superset: + raise ValueError('The superset provided is invalid as it does ' + 'not contain the element {}'.format(elem)) + obj = object.__new__(cls) + obj._subset = subset + obj._superset = superset + return obj + + def __eq__(self, other): + """Return a boolean indicating whether a == b on the basis of + whether both objects are of the class Subset and if the values + of the subset and superset attributes are the same. + """ + if not isinstance(other, Subset): + return NotImplemented + return self.subset == other.subset and self.superset == other.superset + + def iterate_binary(self, k): + """ + This is a helper function. It iterates over the + binary subsets by ``k`` steps. This variable can be + both positive or negative. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset(['c', 'd'], ['a', 'b', 'c', 'd']) + >>> a.iterate_binary(-2).subset + ['d'] + >>> a = Subset(['a', 'b', 'c'], ['a', 'b', 'c', 'd']) + >>> a.iterate_binary(2).subset + [] + + See Also + ======== + + next_binary, prev_binary + """ + bin_list = Subset.bitlist_from_subset(self.subset, self.superset) + n = (int(''.join(bin_list), 2) + k) % 2**self.superset_size + bits = bin(n)[2:].rjust(self.superset_size, '0') + return Subset.subset_from_bitlist(self.superset, bits) + + def next_binary(self): + """ + Generates the next binary ordered subset. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset(['c', 'd'], ['a', 'b', 'c', 'd']) + >>> a.next_binary().subset + ['b'] + >>> a = Subset(['a', 'b', 'c', 'd'], ['a', 'b', 'c', 'd']) + >>> a.next_binary().subset + [] + + See Also + ======== + + prev_binary, iterate_binary + """ + return self.iterate_binary(1) + + def prev_binary(self): + """ + Generates the previous binary ordered subset. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset([], ['a', 'b', 'c', 'd']) + >>> a.prev_binary().subset + ['a', 'b', 'c', 'd'] + >>> a = Subset(['c', 'd'], ['a', 'b', 'c', 'd']) + >>> a.prev_binary().subset + ['c'] + + See Also + ======== + + next_binary, iterate_binary + """ + return self.iterate_binary(-1) + + def next_lexicographic(self): + """ + Generates the next lexicographically ordered subset. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset(['c', 'd'], ['a', 'b', 'c', 'd']) + >>> a.next_lexicographic().subset + ['d'] + >>> a = Subset(['d'], ['a', 'b', 'c', 'd']) + >>> a.next_lexicographic().subset + [] + + See Also + ======== + + prev_lexicographic + """ + i = self.superset_size - 1 + indices = Subset.subset_indices(self.subset, self.superset) + + if i in indices: + if i - 1 in indices: + indices.remove(i - 1) + else: + indices.remove(i) + i = i - 1 + while i >= 0 and i not in indices: + i = i - 1 + if i >= 0: + indices.remove(i) + indices.append(i+1) + else: + while i not in indices and i >= 0: + i = i - 1 + indices.append(i + 1) + + ret_set = [] + super_set = self.superset + for i in indices: + ret_set.append(super_set[i]) + return Subset(ret_set, super_set) + + def prev_lexicographic(self): + """ + Generates the previous lexicographically ordered subset. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset([], ['a', 'b', 'c', 'd']) + >>> a.prev_lexicographic().subset + ['d'] + >>> a = Subset(['c','d'], ['a', 'b', 'c', 'd']) + >>> a.prev_lexicographic().subset + ['c'] + + See Also + ======== + + next_lexicographic + """ + i = self.superset_size - 1 + indices = Subset.subset_indices(self.subset, self.superset) + + while i >= 0 and i not in indices: + i = i - 1 + + if i == 0 or i - 1 in indices: + indices.remove(i) + else: + if i >= 0: + indices.remove(i) + indices.append(i - 1) + indices.append(self.superset_size - 1) + + ret_set = [] + super_set = self.superset + for i in indices: + ret_set.append(super_set[i]) + return Subset(ret_set, super_set) + + def iterate_graycode(self, k): + """ + Helper function used for prev_gray and next_gray. + It performs ``k`` step overs to get the respective Gray codes. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset([1, 2, 3], [1, 2, 3, 4]) + >>> a.iterate_graycode(3).subset + [1, 4] + >>> a.iterate_graycode(-2).subset + [1, 2, 4] + + See Also + ======== + + next_gray, prev_gray + """ + unranked_code = GrayCode.unrank(self.superset_size, + (self.rank_gray + k) % self.cardinality) + return Subset.subset_from_bitlist(self.superset, + unranked_code) + + def next_gray(self): + """ + Generates the next Gray code ordered subset. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset([1, 2, 3], [1, 2, 3, 4]) + >>> a.next_gray().subset + [1, 3] + + See Also + ======== + + iterate_graycode, prev_gray + """ + return self.iterate_graycode(1) + + def prev_gray(self): + """ + Generates the previous Gray code ordered subset. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset([2, 3, 4], [1, 2, 3, 4, 5]) + >>> a.prev_gray().subset + [2, 3, 4, 5] + + See Also + ======== + + iterate_graycode, next_gray + """ + return self.iterate_graycode(-1) + + @property + def rank_binary(self): + """ + Computes the binary ordered rank. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset([], ['a','b','c','d']) + >>> a.rank_binary + 0 + >>> a = Subset(['c', 'd'], ['a', 'b', 'c', 'd']) + >>> a.rank_binary + 3 + + See Also + ======== + + iterate_binary, unrank_binary + """ + if self._rank_binary is None: + self._rank_binary = int("".join( + Subset.bitlist_from_subset(self.subset, + self.superset)), 2) + return self._rank_binary + + @property + def rank_lexicographic(self): + """ + Computes the lexicographic ranking of the subset. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset(['c', 'd'], ['a', 'b', 'c', 'd']) + >>> a.rank_lexicographic + 14 + >>> a = Subset([2, 4, 5], [1, 2, 3, 4, 5, 6]) + >>> a.rank_lexicographic + 43 + """ + if self._rank_lex is None: + def _ranklex(self, subset_index, i, n): + if subset_index == [] or i > n: + return 0 + if i in subset_index: + subset_index.remove(i) + return 1 + _ranklex(self, subset_index, i + 1, n) + return 2**(n - i - 1) + _ranklex(self, subset_index, i + 1, n) + indices = Subset.subset_indices(self.subset, self.superset) + self._rank_lex = _ranklex(self, indices, 0, self.superset_size) + return self._rank_lex + + @property + def rank_gray(self): + """ + Computes the Gray code ranking of the subset. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset(['c','d'], ['a','b','c','d']) + >>> a.rank_gray + 2 + >>> a = Subset([2, 4, 5], [1, 2, 3, 4, 5, 6]) + >>> a.rank_gray + 27 + + See Also + ======== + + iterate_graycode, unrank_gray + """ + if self._rank_graycode is None: + bits = Subset.bitlist_from_subset(self.subset, self.superset) + self._rank_graycode = GrayCode(len(bits), start=bits).rank + return self._rank_graycode + + @property + def subset(self): + """ + Gets the subset represented by the current instance. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset(['c', 'd'], ['a', 'b', 'c', 'd']) + >>> a.subset + ['c', 'd'] + + See Also + ======== + + superset, size, superset_size, cardinality + """ + return self._subset + + @property + def size(self): + """ + Gets the size of the subset. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset(['c', 'd'], ['a', 'b', 'c', 'd']) + >>> a.size + 2 + + See Also + ======== + + subset, superset, superset_size, cardinality + """ + return len(self.subset) + + @property + def superset(self): + """ + Gets the superset of the subset. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset(['c', 'd'], ['a', 'b', 'c', 'd']) + >>> a.superset + ['a', 'b', 'c', 'd'] + + See Also + ======== + + subset, size, superset_size, cardinality + """ + return self._superset + + @property + def superset_size(self): + """ + Returns the size of the superset. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset(['c', 'd'], ['a', 'b', 'c', 'd']) + >>> a.superset_size + 4 + + See Also + ======== + + subset, superset, size, cardinality + """ + return len(self.superset) + + @property + def cardinality(self): + """ + Returns the number of all possible subsets. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> a = Subset(['c', 'd'], ['a', 'b', 'c', 'd']) + >>> a.cardinality + 16 + + See Also + ======== + + subset, superset, size, superset_size + """ + return 2**(self.superset_size) + + @classmethod + def subset_from_bitlist(self, super_set, bitlist): + """ + Gets the subset defined by the bitlist. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> Subset.subset_from_bitlist(['a', 'b', 'c', 'd'], '0011').subset + ['c', 'd'] + + See Also + ======== + + bitlist_from_subset + """ + if len(super_set) != len(bitlist): + raise ValueError("The sizes of the lists are not equal") + ret_set = [] + for i in range(len(bitlist)): + if bitlist[i] == '1': + ret_set.append(super_set[i]) + return Subset(ret_set, super_set) + + @classmethod + def bitlist_from_subset(self, subset, superset): + """ + Gets the bitlist corresponding to a subset. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> Subset.bitlist_from_subset(['c', 'd'], ['a', 'b', 'c', 'd']) + '0011' + + See Also + ======== + + subset_from_bitlist + """ + bitlist = ['0'] * len(superset) + if isinstance(subset, Subset): + subset = subset.subset + for i in Subset.subset_indices(subset, superset): + bitlist[i] = '1' + return ''.join(bitlist) + + @classmethod + def unrank_binary(self, rank, superset): + """ + Gets the binary ordered subset of the specified rank. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> Subset.unrank_binary(4, ['a', 'b', 'c', 'd']).subset + ['b'] + + See Also + ======== + + iterate_binary, rank_binary + """ + bits = bin(rank)[2:].rjust(len(superset), '0') + return Subset.subset_from_bitlist(superset, bits) + + @classmethod + def unrank_gray(self, rank, superset): + """ + Gets the Gray code ordered subset of the specified rank. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> Subset.unrank_gray(4, ['a', 'b', 'c']).subset + ['a', 'b'] + >>> Subset.unrank_gray(0, ['a', 'b', 'c']).subset + [] + + See Also + ======== + + iterate_graycode, rank_gray + """ + graycode_bitlist = GrayCode.unrank(len(superset), rank) + return Subset.subset_from_bitlist(superset, graycode_bitlist) + + @classmethod + def subset_indices(self, subset, superset): + """Return indices of subset in superset in a list; the list is empty + if all elements of ``subset`` are not in ``superset``. + + Examples + ======== + + >>> from sympy.combinatorics import Subset + >>> superset = [1, 3, 2, 5, 4] + >>> Subset.subset_indices([3, 2, 1], superset) + [1, 2, 0] + >>> Subset.subset_indices([1, 6], superset) + [] + >>> Subset.subset_indices([], superset) + [] + + """ + a, b = superset, subset + sb = set(b) + d = {} + for i, ai in enumerate(a): + if ai in sb: + d[ai] = i + sb.remove(ai) + if not sb: + break + else: + return [] + return [d[bi] for bi in b] + + +def ksubsets(superset, k): + """ + Finds the subsets of size ``k`` in lexicographic order. + + This uses the itertools generator. + + Examples + ======== + + >>> from sympy.combinatorics.subsets import ksubsets + >>> list(ksubsets([1, 2, 3], 2)) + [(1, 2), (1, 3), (2, 3)] + >>> list(ksubsets([1, 2, 3, 4, 5], 2)) + [(1, 2), (1, 3), (1, 4), (1, 5), (2, 3), (2, 4), \ + (2, 5), (3, 4), (3, 5), (4, 5)] + + See Also + ======== + + Subset + """ + return combinations(superset, k) diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/tensor_can.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tensor_can.py new file mode 100644 index 0000000000000000000000000000000000000000..d43e96ed27a21f988aaaa8fd16827987541f7373 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/tensor_can.py @@ -0,0 +1,1189 @@ +from sympy.combinatorics.permutations import Permutation, _af_rmul, \ + _af_invert, _af_new +from sympy.combinatorics.perm_groups import PermutationGroup, _orbit, \ + _orbit_transversal +from sympy.combinatorics.util import _distribute_gens_by_base, \ + _orbits_transversals_from_bsgs + +""" + References for tensor canonicalization: + + [1] R. Portugal "Algorithmic simplification of tensor expressions", + J. Phys. A 32 (1999) 7779-7789 + + [2] R. Portugal, B.F. Svaiter "Group-theoretic Approach for Symbolic + Tensor Manipulation: I. Free Indices" + arXiv:math-ph/0107031v1 + + [3] L.R.U. Manssur, R. Portugal "Group-theoretic Approach for Symbolic + Tensor Manipulation: II. Dummy Indices" + arXiv:math-ph/0107032v1 + + [4] xperm.c part of XPerm written by J. M. Martin-Garcia + http://www.xact.es/index.html +""" + + +def dummy_sgs(dummies, sym, n): + """ + Return the strong generators for dummy indices. + + Parameters + ========== + + dummies : List of dummy indices. + `dummies[2k], dummies[2k+1]` are paired indices. + In base form, the dummy indices are always in + consecutive positions. + sym : symmetry under interchange of contracted dummies:: + * None no symmetry + * 0 commuting + * 1 anticommuting + + n : number of indices + + Examples + ======== + + >>> from sympy.combinatorics.tensor_can import dummy_sgs + >>> dummy_sgs(list(range(2, 8)), 0, 8) + [[0, 1, 3, 2, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 5, 4, 6, 7, 8, 9], + [0, 1, 2, 3, 4, 5, 7, 6, 8, 9], [0, 1, 4, 5, 2, 3, 6, 7, 8, 9], + [0, 1, 2, 3, 6, 7, 4, 5, 8, 9]] + """ + if len(dummies) > n: + raise ValueError("List too large") + res = [] + # exchange of contravariant and covariant indices + if sym is not None: + for j in dummies[::2]: + a = list(range(n + 2)) + if sym == 1: + a[n] = n + 1 + a[n + 1] = n + a[j], a[j + 1] = a[j + 1], a[j] + res.append(a) + # rename dummy indices + for j in dummies[:-3:2]: + a = list(range(n + 2)) + a[j:j + 4] = a[j + 2], a[j + 3], a[j], a[j + 1] + res.append(a) + return res + + +def _min_dummies(dummies, sym, indices): + """ + Return list of minima of the orbits of indices in group of dummies. + See ``double_coset_can_rep`` for the description of ``dummies`` and ``sym``. + ``indices`` is the initial list of dummy indices. + + Examples + ======== + + >>> from sympy.combinatorics.tensor_can import _min_dummies + >>> _min_dummies([list(range(2, 8))], [0], list(range(10))) + [0, 1, 2, 2, 2, 2, 2, 2, 8, 9] + """ + num_types = len(sym) + m = [min(dx) if dx else None for dx in dummies] + res = indices[:] + for i in range(num_types): + for c, i in enumerate(indices): + for j in range(num_types): + if i in dummies[j]: + res[c] = m[j] + break + return res + + +def _trace_S(s, j, b, S_cosets): + """ + Return the representative h satisfying s[h[b]] == j + + If there is not such a representative return None + """ + for h in S_cosets[b]: + if s[h[b]] == j: + return h + return None + + +def _trace_D(gj, p_i, Dxtrav): + """ + Return the representative h satisfying h[gj] == p_i + + If there is not such a representative return None + """ + for h in Dxtrav: + if h[gj] == p_i: + return h + return None + + +def _dumx_remove(dumx, dumx_flat, p0): + """ + remove p0 from dumx + """ + res = [] + for dx in dumx: + if p0 not in dx: + res.append(dx) + continue + k = dx.index(p0) + if k % 2 == 0: + p0_paired = dx[k + 1] + else: + p0_paired = dx[k - 1] + dx.remove(p0) + dx.remove(p0_paired) + dumx_flat.remove(p0) + dumx_flat.remove(p0_paired) + res.append(dx) + + +def transversal2coset(size, base, transversal): + a = [] + j = 0 + for i in range(size): + if i in base: + a.append(sorted(transversal[j].values())) + j += 1 + else: + a.append([list(range(size))]) + j = len(a) - 1 + while a[j] == [list(range(size))]: + j -= 1 + return a[:j + 1] + + +def double_coset_can_rep(dummies, sym, b_S, sgens, S_transversals, g): + r""" + Butler-Portugal algorithm for tensor canonicalization with dummy indices. + + Parameters + ========== + + dummies + list of lists of dummy indices, + one list for each type of index; + the dummy indices are put in order contravariant, covariant + [d0, -d0, d1, -d1, ...]. + + sym + list of the symmetries of the index metric for each type. + + possible symmetries of the metrics + * 0 symmetric + * 1 antisymmetric + * None no symmetry + + b_S + base of a minimal slot symmetry BSGS. + + sgens + generators of the slot symmetry BSGS. + + S_transversals + transversals for the slot BSGS. + + g + permutation representing the tensor. + + Returns + ======= + + Return 0 if the tensor is zero, else return the array form of + the permutation representing the canonical form of the tensor. + + Notes + ===== + + A tensor with dummy indices can be represented in a number + of equivalent ways which typically grows exponentially with + the number of indices. To be able to establish if two tensors + with many indices are equal becomes computationally very slow + in absence of an efficient algorithm. + + The Butler-Portugal algorithm [3] is an efficient algorithm to + put tensors in canonical form, solving the above problem. + + Portugal observed that a tensor can be represented by a permutation, + and that the class of tensors equivalent to it under slot and dummy + symmetries is equivalent to the double coset `D*g*S` + (Note: in this documentation we use the conventions for multiplication + of permutations p, q with (p*q)(i) = p[q[i]] which is opposite + to the one used in the Permutation class) + + Using the algorithm by Butler to find a representative of the + double coset one can find a canonical form for the tensor. + + To see this correspondence, + let `g` be a permutation in array form; a tensor with indices `ind` + (the indices including both the contravariant and the covariant ones) + can be written as + + `t = T(ind[g[0]], \dots, ind[g[n-1]])`, + + where `n = len(ind)`; + `g` has size `n + 2`, the last two indices for the sign of the tensor + (trick introduced in [4]). + + A slot symmetry transformation `s` is a permutation acting on the slots + `t \rightarrow T(ind[(g*s)[0]], \dots, ind[(g*s)[n-1]])` + + A dummy symmetry transformation acts on `ind` + `t \rightarrow T(ind[(d*g)[0]], \dots, ind[(d*g)[n-1]])` + + Being interested only in the transformations of the tensor under + these symmetries, one can represent the tensor by `g`, which transforms + as + + `g -> d*g*s`, so it belongs to the coset `D*g*S`, or in other words + to the set of all permutations allowed by the slot and dummy symmetries. + + Let us explain the conventions by an example. + + Given a tensor `T^{d3 d2 d1}{}_{d1 d2 d3}` with the slot symmetries + `T^{a0 a1 a2 a3 a4 a5} = -T^{a2 a1 a0 a3 a4 a5}` + + `T^{a0 a1 a2 a3 a4 a5} = -T^{a4 a1 a2 a3 a0 a5}` + + and symmetric metric, find the tensor equivalent to it which + is the lowest under the ordering of indices: + lexicographic ordering `d1, d2, d3` and then contravariant + before covariant index; that is the canonical form of the tensor. + + The canonical form is `-T^{d1 d2 d3}{}_{d1 d2 d3}` + obtained using `T^{a0 a1 a2 a3 a4 a5} = -T^{a2 a1 a0 a3 a4 a5}`. + + To convert this problem in the input for this function, + use the following ordering of the index names + (- for covariant for short) `d1, -d1, d2, -d2, d3, -d3` + + `T^{d3 d2 d1}{}_{d1 d2 d3}` corresponds to `g = [4, 2, 0, 1, 3, 5, 6, 7]` + where the last two indices are for the sign + + `sgens = [Permutation(0, 2)(6, 7), Permutation(0, 4)(6, 7)]` + + sgens[0] is the slot symmetry `-(0, 2)` + `T^{a0 a1 a2 a3 a4 a5} = -T^{a2 a1 a0 a3 a4 a5}` + + sgens[1] is the slot symmetry `-(0, 4)` + `T^{a0 a1 a2 a3 a4 a5} = -T^{a4 a1 a2 a3 a0 a5}` + + The dummy symmetry group D is generated by the strong base generators + `[(0, 1), (2, 3), (4, 5), (0, 2)(1, 3), (0, 4)(1, 5)]` + where the first three interchange covariant and contravariant + positions of the same index (d1 <-> -d1) and the last two interchange + the dummy indices themselves (d1 <-> d2). + + The dummy symmetry acts from the left + `d = [1, 0, 2, 3, 4, 5, 6, 7]` exchange `d1 \leftrightarrow -d1` + `T^{d3 d2 d1}{}_{d1 d2 d3} == T^{d3 d2}{}_{d1}{}^{d1}{}_{d2 d3}` + + `g=[4, 2, 0, 1, 3, 5, 6, 7] -> [4, 2, 1, 0, 3, 5, 6, 7] = _af_rmul(d, g)` + which differs from `_af_rmul(g, d)`. + + The slot symmetry acts from the right + `s = [2, 1, 0, 3, 4, 5, 7, 6]` exchanges slots 0 and 2 and changes sign + `T^{d3 d2 d1}{}_{d1 d2 d3} == -T^{d1 d2 d3}{}_{d1 d2 d3}` + + `g=[4,2,0,1,3,5,6,7] -> [0, 2, 4, 1, 3, 5, 7, 6] = _af_rmul(g, s)` + + Example in which the tensor is zero, same slot symmetries as above: + `T^{d2}{}_{d1 d3}{}^{d1 d3}{}_{d2}` + + `= -T^{d3}{}_{d1 d3}{}^{d1 d2}{}_{d2}` under slot symmetry `-(0,4)`; + + `= T_{d3 d1}{}^{d3}{}^{d1 d2}{}_{d2}` under slot symmetry `-(0,2)`; + + `= T^{d3}{}_{d1 d3}{}^{d1 d2}{}_{d2}` symmetric metric; + + `= 0` since two of these lines have tensors differ only for the sign. + + The double coset D*g*S consists of permutations `h = d*g*s` corresponding + to equivalent tensors; if there are two `h` which are the same apart + from the sign, return zero; otherwise + choose as representative the tensor with indices + ordered lexicographically according to `[d1, -d1, d2, -d2, d3, -d3]` + that is ``rep = min(D*g*S) = min([d*g*s for d in D for s in S])`` + + The indices are fixed one by one; first choose the lowest index + for slot 0, then the lowest remaining index for slot 1, etc. + Doing this one obtains a chain of stabilizers + + `S \rightarrow S_{b0} \rightarrow S_{b0,b1} \rightarrow \dots` and + `D \rightarrow D_{p0} \rightarrow D_{p0,p1} \rightarrow \dots` + + where ``[b0, b1, ...] = range(b)`` is a base of the symmetric group; + the strong base `b_S` of S is an ordered sublist of it; + therefore it is sufficient to compute once the + strong base generators of S using the Schreier-Sims algorithm; + the stabilizers of the strong base generators are the + strong base generators of the stabilizer subgroup. + + ``dbase = [p0, p1, ...]`` is not in general in lexicographic order, + so that one must recompute the strong base generators each time; + however this is trivial, there is no need to use the Schreier-Sims + algorithm for D. + + The algorithm keeps a TAB of elements `(s_i, d_i, h_i)` + where `h_i = d_i \times g \times s_i` satisfying `h_i[j] = p_j` for `0 \le j < i` + starting from `s_0 = id, d_0 = id, h_0 = g`. + + The equations `h_0[0] = p_0, h_1[1] = p_1, \dots` are solved in this order, + choosing each time the lowest possible value of p_i + + For `j < i` + `d_i*g*s_i*S_{b_0, \dots, b_{i-1}}*b_j = D_{p_0, \dots, p_{i-1}}*p_j` + so that for dx in `D_{p_0,\dots,p_{i-1}}` and sx in + `S_{base[0], \dots, base[i-1]}` one has `dx*d_i*g*s_i*sx*b_j = p_j` + + Search for dx, sx such that this equation holds for `j = i`; + it can be written as `s_i*sx*b_j = J, dx*d_i*g*J = p_j` + `sx*b_j = s_i**-1*J; sx = trace(s_i**-1, S_{b_0,...,b_{i-1}})` + `dx**-1*p_j = d_i*g*J; dx = trace(d_i*g*J, D_{p_0,...,p_{i-1}})` + + `s_{i+1} = s_i*trace(s_i**-1*J, S_{b_0,...,b_{i-1}})` + `d_{i+1} = trace(d_i*g*J, D_{p_0,...,p_{i-1}})**-1*d_i` + `h_{i+1}*b_i = d_{i+1}*g*s_{i+1}*b_i = p_i` + + `h_n*b_j = p_j` for all j, so that `h_n` is the solution. + + Add the found `(s, d, h)` to TAB1. + + At the end of the iteration sort TAB1 with respect to the `h`; + if there are two consecutive `h` in TAB1 which differ only for the + sign, the tensor is zero, so return 0; + if there are two consecutive `h` which are equal, keep only one. + + Then stabilize the slot generators under `i` and the dummy generators + under `p_i`. + + Assign `TAB = TAB1` at the end of the iteration step. + + At the end `TAB` contains a unique `(s, d, h)`, since all the slots + of the tensor `h` have been fixed to have the minimum value according + to the symmetries. The algorithm returns `h`. + + It is important that the slot BSGS has lexicographic minimal base, + otherwise there is an `i` which does not belong to the slot base + for which `p_i` is fixed by the dummy symmetry only, while `i` + is not invariant from the slot stabilizer, so `p_i` is not in + general the minimal value. + + This algorithm differs slightly from the original algorithm [3]: + the canonical form is minimal lexicographically, and + the BSGS has minimal base under lexicographic order. + Equal tensors `h` are eliminated from TAB. + + + Examples + ======== + + >>> from sympy.combinatorics.permutations import Permutation + >>> from sympy.combinatorics.tensor_can import double_coset_can_rep, get_transversals + >>> gens = [Permutation(x) for x in [[2, 1, 0, 3, 4, 5, 7, 6], [4, 1, 2, 3, 0, 5, 7, 6]]] + >>> base = [0, 2] + >>> g = Permutation([4, 2, 0, 1, 3, 5, 6, 7]) + >>> transversals = get_transversals(base, gens) + >>> double_coset_can_rep([list(range(6))], [0], base, gens, transversals, g) + [0, 1, 2, 3, 4, 5, 7, 6] + + >>> g = Permutation([4, 1, 3, 0, 5, 2, 6, 7]) + >>> double_coset_can_rep([list(range(6))], [0], base, gens, transversals, g) + 0 + """ + size = g.size + g = g.array_form + num_dummies = size - 2 + indices = list(range(num_dummies)) + all_metrics_with_sym = not any(_ is None for _ in sym) + num_types = len(sym) + dumx = dummies[:] + dumx_flat = [] + for dx in dumx: + dumx_flat.extend(dx) + b_S = b_S[:] + sgensx = [h._array_form for h in sgens] + if b_S: + S_transversals = transversal2coset(size, b_S, S_transversals) + # strong generating set for D + dsgsx = [] + for i in range(num_types): + dsgsx.extend(dummy_sgs(dumx[i], sym[i], num_dummies)) + idn = list(range(size)) + # TAB = list of entries (s, d, h) where h = _af_rmuln(d,g,s) + # for short, in the following d*g*s means _af_rmuln(d,g,s) + TAB = [(idn, idn, g)] + for i in range(size - 2): + b = i + testb = b in b_S and sgensx + if testb: + sgensx1 = [_af_new(_) for _ in sgensx] + deltab = _orbit(size, sgensx1, b) + else: + deltab = {b} + # p1 = min(IMAGES) = min(Union D_p*h*deltab for h in TAB) + if all_metrics_with_sym: + md = _min_dummies(dumx, sym, indices) + else: + md = [min(_orbit(size, [_af_new( + ddx) for ddx in dsgsx], ii)) for ii in range(size - 2)] + + p_i = min(min(md[h[x]] for x in deltab) for s, d, h in TAB) + dsgsx1 = [_af_new(_) for _ in dsgsx] + Dxtrav = _orbit_transversal(size, dsgsx1, p_i, False, af=True) \ + if dsgsx else None + if Dxtrav: + Dxtrav = [_af_invert(x) for x in Dxtrav] + # compute the orbit of p_i + for ii in range(num_types): + if p_i in dumx[ii]: + # the orbit is made by all the indices in dum[ii] + if sym[ii] is not None: + deltap = dumx[ii] + else: + # the orbit is made by all the even indices if p_i + # is even, by all the odd indices if p_i is odd + p_i_index = dumx[ii].index(p_i) % 2 + deltap = dumx[ii][p_i_index::2] + break + else: + deltap = [p_i] + TAB1 = [] + while TAB: + s, d, h = TAB.pop() + if min(md[h[x]] for x in deltab) != p_i: + continue + deltab1 = [x for x in deltab if md[h[x]] == p_i] + # NEXT = s*deltab1 intersection (d*g)**-1*deltap + dg = _af_rmul(d, g) + dginv = _af_invert(dg) + sdeltab = [s[x] for x in deltab1] + gdeltap = [dginv[x] for x in deltap] + NEXT = [x for x in sdeltab if x in gdeltap] + # d, s satisfy + # d*g*s*base[i-1] = p_{i-1}; using the stabilizers + # d*g*s*S_{base[0],...,base[i-1]}*base[i-1] = + # D_{p_0,...,p_{i-1}}*p_{i-1} + # so that to find d1, s1 satisfying d1*g*s1*b = p_i + # one can look for dx in D_{p_0,...,p_{i-1}} and + # sx in S_{base[0],...,base[i-1]} + # d1 = dx*d; s1 = s*sx + # d1*g*s1*b = dx*d*g*s*sx*b = p_i + for j in NEXT: + if testb: + # solve s1*b = j with s1 = s*sx for some element sx + # of the stabilizer of ..., base[i-1] + # sx*b = s**-1*j; sx = _trace_S(s, j,...) + # s1 = s*trace_S(s**-1*j,...) + s1 = _trace_S(s, j, b, S_transversals) + if not s1: + continue + else: + s1 = [s[ix] for ix in s1] + else: + s1 = s + # assert s1[b] == j # invariant + # solve d1*g*j = p_i with d1 = dx*d for some element dg + # of the stabilizer of ..., p_{i-1} + # dx**-1*p_i = d*g*j; dx**-1 = trace_D(d*g*j,...) + # d1 = trace_D(d*g*j,...)**-1*d + # to save an inversion in the inner loop; notice we did + # Dxtrav = [perm_af_invert(x) for x in Dxtrav] out of the loop + if Dxtrav: + d1 = _trace_D(dg[j], p_i, Dxtrav) + if not d1: + continue + else: + if p_i != dg[j]: + continue + d1 = idn + assert d1[dg[j]] == p_i # invariant + d1 = [d1[ix] for ix in d] + h1 = [d1[g[ix]] for ix in s1] + # assert h1[b] == p_i # invariant + TAB1.append((s1, d1, h1)) + + # if TAB contains equal permutations, keep only one of them; + # if TAB contains equal permutations up to the sign, return 0 + TAB1.sort(key=lambda x: x[-1]) + prev = [0] * size + while TAB1: + s, d, h = TAB1.pop() + if h[:-2] == prev[:-2]: + if h[-1] != prev[-1]: + return 0 + else: + TAB.append((s, d, h)) + prev = h + + # stabilize the SGS + sgensx = [h for h in sgensx if h[b] == b] + if b in b_S: + b_S.remove(b) + _dumx_remove(dumx, dumx_flat, p_i) + dsgsx = [] + for i in range(num_types): + dsgsx.extend(dummy_sgs(dumx[i], sym[i], num_dummies)) + return TAB[0][-1] + + +def canonical_free(base, gens, g, num_free): + """ + Canonicalization of a tensor with respect to free indices + choosing the minimum with respect to lexicographical ordering + in the free indices. + + Explanation + =========== + + ``base``, ``gens`` BSGS for slot permutation group + ``g`` permutation representing the tensor + ``num_free`` number of free indices + The indices must be ordered with first the free indices + + See explanation in double_coset_can_rep + The algorithm is a variation of the one given in [2]. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> from sympy.combinatorics.tensor_can import canonical_free + >>> gens = [[1, 0, 2, 3, 5, 4], [2, 3, 0, 1, 4, 5],[0, 1, 3, 2, 5, 4]] + >>> gens = [Permutation(h) for h in gens] + >>> base = [0, 2] + >>> g = Permutation([2, 1, 0, 3, 4, 5]) + >>> canonical_free(base, gens, g, 4) + [0, 3, 1, 2, 5, 4] + + Consider the product of Riemann tensors + ``T = R^{a}_{d0}^{d1,d2}*R_{d2,d1}^{d0,b}`` + The order of the indices is ``[a, b, d0, -d0, d1, -d1, d2, -d2]`` + The permutation corresponding to the tensor is + ``g = [0, 3, 4, 6, 7, 5, 2, 1, 8, 9]`` + + In particular ``a`` is position ``0``, ``b`` is in position ``9``. + Use the slot symmetries to get `T` is a form which is the minimal + in lexicographic order in the free indices ``a`` and ``b``, e.g. + ``-R^{a}_{d0}^{d1,d2}*R^{b,d0}_{d2,d1}`` corresponding to + ``[0, 3, 4, 6, 1, 2, 7, 5, 9, 8]`` + + >>> from sympy.combinatorics.tensor_can import riemann_bsgs, tensor_gens + >>> base, gens = riemann_bsgs + >>> size, sbase, sgens = tensor_gens(base, gens, [[], []], 0) + >>> g = Permutation([0, 3, 4, 6, 7, 5, 2, 1, 8, 9]) + >>> canonical_free(sbase, [Permutation(h) for h in sgens], g, 2) + [0, 3, 4, 6, 1, 2, 7, 5, 9, 8] + """ + g = g.array_form + size = len(g) + if not base: + return g[:] + + transversals = get_transversals(base, gens) + for x in sorted(g[:-2]): + if x not in base: + base.append(x) + h = g + for transv in transversals: + h_i = [size]*num_free + # find the element s in transversals[i] such that + # _af_rmul(h, s) has its free elements with the lowest position in h + s = None + for sk in transv.values(): + h1 = _af_rmul(h, sk) + hi = [h1.index(ix) for ix in range(num_free)] + if hi < h_i: + h_i = hi + s = sk + if s: + h = _af_rmul(h, s) + return h + + +def _get_map_slots(size, fixed_slots): + res = list(range(size)) + pos = 0 + for i in range(size): + if i in fixed_slots: + continue + res[i] = pos + pos += 1 + return res + + +def _lift_sgens(size, fixed_slots, free, s): + a = [] + j = k = 0 + fd = [y for _, y in sorted(zip(fixed_slots, free))] + num_free = len(free) + for i in range(size): + if i in fixed_slots: + a.append(fd[k]) + k += 1 + else: + a.append(s[j] + num_free) + j += 1 + return a + + +def canonicalize(g, dummies, msym, *v): + """ + canonicalize tensor formed by tensors + + Parameters + ========== + + g : permutation representing the tensor + + dummies : list representing the dummy indices + it can be a list of dummy indices of the same type + or a list of lists of dummy indices, one list for each + type of index; + the dummy indices must come after the free indices, + and put in order contravariant, covariant + [d0, -d0, d1,-d1,...] + + msym : symmetry of the metric(s) + it can be an integer or a list; + in the first case it is the symmetry of the dummy index metric; + in the second case it is the list of the symmetries of the + index metric for each type + + v : list, (base_i, gens_i, n_i, sym_i) for tensors of type `i` + + base_i, gens_i : BSGS for tensors of this type. + The BSGS should have minimal base under lexicographic ordering; + if not, an attempt is made do get the minimal BSGS; + in case of failure, + canonicalize_naive is used, which is much slower. + + n_i : number of tensors of type `i`. + + sym_i : symmetry under exchange of component tensors of type `i`. + + Both for msym and sym_i the cases are + * None no symmetry + * 0 commuting + * 1 anticommuting + + Returns + ======= + + 0 if the tensor is zero, else return the array form of + the permutation representing the canonical form of the tensor. + + Algorithm + ========= + + First one uses canonical_free to get the minimum tensor under + lexicographic order, using only the slot symmetries. + If the component tensors have not minimal BSGS, it is attempted + to find it; if the attempt fails canonicalize_naive + is used instead. + + Compute the residual slot symmetry keeping fixed the free indices + using tensor_gens(base, gens, list_free_indices, sym). + + Reduce the problem eliminating the free indices. + + Then use double_coset_can_rep and lift back the result reintroducing + the free indices. + + Examples + ======== + + one type of index with commuting metric; + + `A_{a b}` and `B_{a b}` antisymmetric and commuting + + `T = A_{d0 d1} * B^{d0}{}_{d2} * B^{d2 d1}` + + `ord = [d0,-d0,d1,-d1,d2,-d2]` order of the indices + + g = [1, 3, 0, 5, 4, 2, 6, 7] + + `T_c = 0` + + >>> from sympy.combinatorics.tensor_can import get_symmetric_group_sgs, canonicalize, bsgs_direct_product + >>> from sympy.combinatorics import Permutation + >>> base2a, gens2a = get_symmetric_group_sgs(2, 1) + >>> t0 = (base2a, gens2a, 1, 0) + >>> t1 = (base2a, gens2a, 2, 0) + >>> g = Permutation([1, 3, 0, 5, 4, 2, 6, 7]) + >>> canonicalize(g, range(6), 0, t0, t1) + 0 + + same as above, but with `B_{a b}` anticommuting + + `T_c = -A^{d0 d1} * B_{d0}{}^{d2} * B_{d1 d2}` + + can = [0,2,1,4,3,5,7,6] + + >>> t1 = (base2a, gens2a, 2, 1) + >>> canonicalize(g, range(6), 0, t0, t1) + [0, 2, 1, 4, 3, 5, 7, 6] + + two types of indices `[a,b,c,d,e,f]` and `[m,n]`, in this order, + both with commuting metric + + `f^{a b c}` antisymmetric, commuting + + `A_{m a}` no symmetry, commuting + + `T = f^c{}_{d a} * f^f{}_{e b} * A_m{}^d * A^{m b} * A_n{}^a * A^{n e}` + + ord = [c,f,a,-a,b,-b,d,-d,e,-e,m,-m,n,-n] + + g = [0,7,3, 1,9,5, 11,6, 10,4, 13,2, 12,8, 14,15] + + The canonical tensor is + `T_c = -f^{c a b} * f^{f d e} * A^m{}_a * A_{m d} * A^n{}_b * A_{n e}` + + can = [0,2,4, 1,6,8, 10,3, 11,7, 12,5, 13,9, 15,14] + + >>> base_f, gens_f = get_symmetric_group_sgs(3, 1) + >>> base1, gens1 = get_symmetric_group_sgs(1) + >>> base_A, gens_A = bsgs_direct_product(base1, gens1, base1, gens1) + >>> t0 = (base_f, gens_f, 2, 0) + >>> t1 = (base_A, gens_A, 4, 0) + >>> dummies = [range(2, 10), range(10, 14)] + >>> g = Permutation([0, 7, 3, 1, 9, 5, 11, 6, 10, 4, 13, 2, 12, 8, 14, 15]) + >>> canonicalize(g, dummies, [0, 0], t0, t1) + [0, 2, 4, 1, 6, 8, 10, 3, 11, 7, 12, 5, 13, 9, 15, 14] + """ + from sympy.combinatorics.testutil import canonicalize_naive + if not isinstance(msym, list): + if msym not in (0, 1, None): + raise ValueError('msym must be 0, 1 or None') + num_types = 1 + else: + num_types = len(msym) + if not all(msymx in (0, 1, None) for msymx in msym): + raise ValueError('msym entries must be 0, 1 or None') + if len(dummies) != num_types: + raise ValueError( + 'dummies and msym must have the same number of elements') + size = g.size + num_tensors = 0 + v1 = [] + for base_i, gens_i, n_i, sym_i in v: + # check that the BSGS is minimal; + # this property is used in double_coset_can_rep; + # if it is not minimal use canonicalize_naive + if not _is_minimal_bsgs(base_i, gens_i): + mbsgs = get_minimal_bsgs(base_i, gens_i) + if not mbsgs: + can = canonicalize_naive(g, dummies, msym, *v) + return can + base_i, gens_i = mbsgs + v1.append((base_i, gens_i, [[]] * n_i, sym_i)) + num_tensors += n_i + + if num_types == 1 and not isinstance(msym, list): + dummies = [dummies] + msym = [msym] + flat_dummies = [] + for dumx in dummies: + flat_dummies.extend(dumx) + + if flat_dummies and flat_dummies != list(range(flat_dummies[0], flat_dummies[-1] + 1)): + raise ValueError('dummies is not valid') + + # slot symmetry of the tensor + size1, sbase, sgens = gens_products(*v1) + if size != size1: + raise ValueError( + 'g has size %d, generators have size %d' % (size, size1)) + free = [i for i in range(size - 2) if i not in flat_dummies] + num_free = len(free) + + # g1 minimal tensor under slot symmetry + g1 = canonical_free(sbase, sgens, g, num_free) + if not flat_dummies: + return g1 + # save the sign of g1 + sign = 0 if g1[-1] == size - 1 else 1 + + # the free indices are kept fixed. + # Determine free_i, the list of slots of tensors which are fixed + # since they are occupied by free indices, which are fixed. + start = 0 + for i, (base_i, gens_i, n_i, sym_i) in enumerate(v): + free_i = [] + len_tens = gens_i[0].size - 2 + # for each component tensor get a list od fixed islots + for j in range(n_i): + # get the elements corresponding to the component tensor + h = g1[start:(start + len_tens)] + fr = [] + # get the positions of the fixed elements in h + for k in free: + if k in h: + fr.append(h.index(k)) + free_i.append(fr) + start += len_tens + v1[i] = (base_i, gens_i, free_i, sym_i) + # BSGS of the tensor with fixed free indices + # if tensor_gens fails in gens_product, use canonicalize_naive + size, sbase, sgens = gens_products(*v1) + + # reduce the permutations getting rid of the free indices + pos_free = [g1.index(x) for x in range(num_free)] + size_red = size - num_free + g1_red = [x - num_free for x in g1 if x in flat_dummies] + if sign: + g1_red.extend([size_red - 1, size_red - 2]) + else: + g1_red.extend([size_red - 2, size_red - 1]) + map_slots = _get_map_slots(size, pos_free) + sbase_red = [map_slots[i] for i in sbase if i not in pos_free] + sgens_red = [_af_new([map_slots[i] for i in y._array_form if i not in pos_free]) for y in sgens] + dummies_red = [[x - num_free for x in y] for y in dummies] + transv_red = get_transversals(sbase_red, sgens_red) + g1_red = _af_new(g1_red) + g2 = double_coset_can_rep( + dummies_red, msym, sbase_red, sgens_red, transv_red, g1_red) + if g2 == 0: + return 0 + # lift to the case with the free indices + g3 = _lift_sgens(size, pos_free, free, g2) + return g3 + + +def perm_af_direct_product(gens1, gens2, signed=True): + """ + Direct products of the generators gens1 and gens2. + + Examples + ======== + + >>> from sympy.combinatorics.tensor_can import perm_af_direct_product + >>> gens1 = [[1, 0, 2, 3], [0, 1, 3, 2]] + >>> gens2 = [[1, 0]] + >>> perm_af_direct_product(gens1, gens2, False) + [[1, 0, 2, 3, 4, 5], [0, 1, 3, 2, 4, 5], [0, 1, 2, 3, 5, 4]] + >>> gens1 = [[1, 0, 2, 3, 5, 4], [0, 1, 3, 2, 4, 5]] + >>> gens2 = [[1, 0, 2, 3]] + >>> perm_af_direct_product(gens1, gens2, True) + [[1, 0, 2, 3, 4, 5, 7, 6], [0, 1, 3, 2, 4, 5, 6, 7], [0, 1, 2, 3, 5, 4, 6, 7]] + """ + gens1 = [list(x) for x in gens1] + gens2 = [list(x) for x in gens2] + s = 2 if signed else 0 + n1 = len(gens1[0]) - s + n2 = len(gens2[0]) - s + start = list(range(n1)) + end = list(range(n1, n1 + n2)) + if signed: + gens1 = [gen[:-2] + end + [gen[-2] + n2, gen[-1] + n2] + for gen in gens1] + gens2 = [start + [x + n1 for x in gen] for gen in gens2] + else: + gens1 = [gen + end for gen in gens1] + gens2 = [start + [x + n1 for x in gen] for gen in gens2] + + res = gens1 + gens2 + + return res + + +def bsgs_direct_product(base1, gens1, base2, gens2, signed=True): + """ + Direct product of two BSGS. + + Parameters + ========== + + base1 : base of the first BSGS. + + gens1 : strong generating sequence of the first BSGS. + + base2, gens2 : similarly for the second BSGS. + + signed : flag for signed permutations. + + Examples + ======== + + >>> from sympy.combinatorics.tensor_can import (get_symmetric_group_sgs, bsgs_direct_product) + >>> base1, gens1 = get_symmetric_group_sgs(1) + >>> base2, gens2 = get_symmetric_group_sgs(2) + >>> bsgs_direct_product(base1, gens1, base2, gens2) + ([1], [(4)(1 2)]) + """ + s = 2 if signed else 0 + n1 = gens1[0].size - s + base = list(base1) + base += [x + n1 for x in base2] + gens1 = [h._array_form for h in gens1] + gens2 = [h._array_form for h in gens2] + gens = perm_af_direct_product(gens1, gens2, signed) + size = len(gens[0]) + id_af = list(range(size)) + gens = [h for h in gens if h != id_af] + if not gens: + gens = [id_af] + return base, [_af_new(h) for h in gens] + + +def get_symmetric_group_sgs(n, antisym=False): + """ + Return base, gens of the minimal BSGS for (anti)symmetric tensor + + Parameters + ========== + + n : rank of the tensor + antisym : bool + ``antisym = False`` symmetric tensor + ``antisym = True`` antisymmetric tensor + + Examples + ======== + + >>> from sympy.combinatorics.tensor_can import get_symmetric_group_sgs + >>> get_symmetric_group_sgs(3) + ([0, 1], [(4)(0 1), (4)(1 2)]) + """ + if n == 1: + return [], [_af_new(list(range(3)))] + gens = [Permutation(n - 1)(i, i + 1)._array_form for i in range(n - 1)] + if antisym == 0: + gens = [x + [n, n + 1] for x in gens] + else: + gens = [x + [n + 1, n] for x in gens] + base = list(range(n - 1)) + return base, [_af_new(h) for h in gens] + +riemann_bsgs = [0, 2], [Permutation(0, 1)(4, 5), Permutation(2, 3)(4, 5), + Permutation(5)(0, 2)(1, 3)] + + +def get_transversals(base, gens): + """ + Return transversals for the group with BSGS base, gens + """ + if not base: + return [] + stabs = _distribute_gens_by_base(base, gens) + orbits, transversals = _orbits_transversals_from_bsgs(base, stabs) + transversals = [{x: h._array_form for x, h in y.items()} for y in + transversals] + return transversals + + +def _is_minimal_bsgs(base, gens): + """ + Check if the BSGS has minimal base under lexigographic order. + + base, gens BSGS + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> from sympy.combinatorics.tensor_can import riemann_bsgs, _is_minimal_bsgs + >>> _is_minimal_bsgs(*riemann_bsgs) + True + >>> riemann_bsgs1 = ([2, 0], ([Permutation(5)(0, 1)(4, 5), Permutation(5)(0, 2)(1, 3)])) + >>> _is_minimal_bsgs(*riemann_bsgs1) + False + """ + base1 = [] + sgs1 = gens[:] + size = gens[0].size + for i in range(size): + if not all(h._array_form[i] == i for h in sgs1): + base1.append(i) + sgs1 = [h for h in sgs1 if h._array_form[i] == i] + return base1 == base + + +def get_minimal_bsgs(base, gens): + """ + Compute a minimal GSGS + + base, gens BSGS + + If base, gens is a minimal BSGS return it; else return a minimal BSGS + if it fails in finding one, it returns None + + TODO: use baseswap in the case in which if it fails in finding a + minimal BSGS + + Examples + ======== + + >>> from sympy.combinatorics import Permutation + >>> from sympy.combinatorics.tensor_can import get_minimal_bsgs + >>> riemann_bsgs1 = ([2, 0], ([Permutation(5)(0, 1)(4, 5), Permutation(5)(0, 2)(1, 3)])) + >>> get_minimal_bsgs(*riemann_bsgs1) + ([0, 2], [(0 1)(4 5), (5)(0 2)(1 3), (2 3)(4 5)]) + """ + G = PermutationGroup(gens) + base, gens = G.schreier_sims_incremental() + if not _is_minimal_bsgs(base, gens): + return None + return base, gens + + +def tensor_gens(base, gens, list_free_indices, sym=0): + """ + Returns size, res_base, res_gens BSGS for n tensors of the + same type. + + Explanation + =========== + + base, gens BSGS for tensors of this type + list_free_indices list of the slots occupied by fixed indices + for each of the tensors + + sym symmetry under commutation of two tensors + sym None no symmetry + sym 0 commuting + sym 1 anticommuting + + Examples + ======== + + >>> from sympy.combinatorics.tensor_can import tensor_gens, get_symmetric_group_sgs + + two symmetric tensors with 3 indices without free indices + + >>> base, gens = get_symmetric_group_sgs(3) + >>> tensor_gens(base, gens, [[], []]) + (8, [0, 1, 3, 4], [(7)(0 1), (7)(1 2), (7)(3 4), (7)(4 5), (7)(0 3)(1 4)(2 5)]) + + two symmetric tensors with 3 indices with free indices in slot 1 and 0 + + >>> tensor_gens(base, gens, [[1], [0]]) + (8, [0, 4], [(7)(0 2), (7)(4 5)]) + + four symmetric tensors with 3 indices, two of which with free indices + + """ + def _get_bsgs(G, base, gens, free_indices): + """ + return the BSGS for G.pointwise_stabilizer(free_indices) + """ + if not free_indices: + return base[:], gens[:] + else: + H = G.pointwise_stabilizer(free_indices) + base, sgs = H.schreier_sims_incremental() + return base, sgs + + # if not base there is no slot symmetry for the component tensors + # if list_free_indices.count([]) < 2 there is no commutation symmetry + # so there is no resulting slot symmetry + if not base and list_free_indices.count([]) < 2: + n = len(list_free_indices) + size = gens[0].size + size = n * (size - 2) + 2 + return size, [], [_af_new(list(range(size)))] + + # if any(list_free_indices) one needs to compute the pointwise + # stabilizer, so G is needed + if any(list_free_indices): + G = PermutationGroup(gens) + else: + G = None + + # no_free list of lists of indices for component tensors without fixed + # indices + no_free = [] + size = gens[0].size + id_af = list(range(size)) + num_indices = size - 2 + if not list_free_indices[0]: + no_free.append(list(range(num_indices))) + res_base, res_gens = _get_bsgs(G, base, gens, list_free_indices[0]) + for i in range(1, len(list_free_indices)): + base1, gens1 = _get_bsgs(G, base, gens, list_free_indices[i]) + res_base, res_gens = bsgs_direct_product(res_base, res_gens, + base1, gens1, 1) + if not list_free_indices[i]: + no_free.append(list(range(size - 2, size - 2 + num_indices))) + size += num_indices + nr = size - 2 + res_gens = [h for h in res_gens if h._array_form != id_af] + # if sym there are no commuting tensors stop here + if sym is None or not no_free: + if not res_gens: + res_gens = [_af_new(id_af)] + return size, res_base, res_gens + + # if the component tensors have moinimal BSGS, so is their direct + # product P; the slot symmetry group is S = P*C, where C is the group + # to (anti)commute the component tensors with no free indices + # a stabilizer has the property S_i = P_i*C_i; + # the BSGS of P*C has SGS_P + SGS_C and the base is + # the ordered union of the bases of P and C. + # If P has minimal BSGS, so has S with this base. + base_comm = [] + for i in range(len(no_free) - 1): + ind1 = no_free[i] + ind2 = no_free[i + 1] + a = list(range(ind1[0])) + a.extend(ind2) + a.extend(ind1) + base_comm.append(ind1[0]) + a.extend(list(range(ind2[-1] + 1, nr))) + if sym == 0: + a.extend([nr, nr + 1]) + else: + a.extend([nr + 1, nr]) + res_gens.append(_af_new(a)) + res_base = list(res_base) + # each base is ordered; order the union of the two bases + for i in base_comm: + if i not in res_base: + res_base.append(i) + res_base.sort() + if not res_gens: + res_gens = [_af_new(id_af)] + + return size, res_base, res_gens + + +def gens_products(*v): + """ + Returns size, res_base, res_gens BSGS for n tensors of different types. + + Explanation + =========== + + v is a sequence of (base_i, gens_i, free_i, sym_i) + where + base_i, gens_i BSGS of tensor of type `i` + free_i list of the fixed slots for each of the tensors + of type `i`; if there are `n_i` tensors of type `i` + and none of them have fixed slots, `free = [[]]*n_i` + sym 0 (1) if the tensors of type `i` (anti)commute among themselves + + Examples + ======== + + >>> from sympy.combinatorics.tensor_can import get_symmetric_group_sgs, gens_products + >>> base, gens = get_symmetric_group_sgs(2) + >>> gens_products((base, gens, [[], []], 0)) + (6, [0, 2], [(5)(0 1), (5)(2 3), (5)(0 2)(1 3)]) + >>> gens_products((base, gens, [[1], []], 0)) + (6, [2], [(5)(2 3)]) + """ + res_size, res_base, res_gens = tensor_gens(*v[0]) + for i in range(1, len(v)): + size, base, gens = tensor_gens(*v[i]) + res_base, res_gens = bsgs_direct_product(res_base, res_gens, base, + gens, 1) + res_size = res_gens[0].size + id_af = list(range(res_size)) + res_gens = [h for h in res_gens if h != id_af] + if not res_gens: + res_gens = [id_af] + return res_size, res_base, res_gens diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/testutil.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/testutil.py new file mode 100644 index 0000000000000000000000000000000000000000..9fe664ce9e7437b97feec90f2052eb2987c57a4e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/testutil.py @@ -0,0 +1,357 @@ +from sympy.combinatorics import Permutation +from sympy.combinatorics.util import _distribute_gens_by_base + +rmul = Permutation.rmul + + +def _cmp_perm_lists(first, second): + """ + Compare two lists of permutations as sets. + + Explanation + =========== + + This is used for testing purposes. Since the array form of a + permutation is currently a list, Permutation is not hashable + and cannot be put into a set. + + Examples + ======== + + >>> from sympy.combinatorics.permutations import Permutation + >>> from sympy.combinatorics.testutil import _cmp_perm_lists + >>> a = Permutation([0, 2, 3, 4, 1]) + >>> b = Permutation([1, 2, 0, 4, 3]) + >>> c = Permutation([3, 4, 0, 1, 2]) + >>> ls1 = [a, b, c] + >>> ls2 = [b, c, a] + >>> _cmp_perm_lists(ls1, ls2) + True + + """ + return {tuple(a) for a in first} == \ + {tuple(a) for a in second} + + +def _naive_list_centralizer(self, other, af=False): + from sympy.combinatorics.perm_groups import PermutationGroup + """ + Return a list of elements for the centralizer of a subgroup/set/element. + + Explanation + =========== + + This is a brute force implementation that goes over all elements of the + group and checks for membership in the centralizer. It is used to + test ``.centralizer()`` from ``sympy.combinatorics.perm_groups``. + + Examples + ======== + + >>> from sympy.combinatorics.testutil import _naive_list_centralizer + >>> from sympy.combinatorics.named_groups import DihedralGroup + >>> D = DihedralGroup(4) + >>> _naive_list_centralizer(D, D) + [Permutation([0, 1, 2, 3]), Permutation([2, 3, 0, 1])] + + See Also + ======== + + sympy.combinatorics.perm_groups.centralizer + + """ + from sympy.combinatorics.permutations import _af_commutes_with + if hasattr(other, 'generators'): + elements = list(self.generate_dimino(af=True)) + gens = [x._array_form for x in other.generators] + commutes_with_gens = lambda x: all(_af_commutes_with(x, gen) for gen in gens) + centralizer_list = [] + if not af: + for element in elements: + if commutes_with_gens(element): + centralizer_list.append(Permutation._af_new(element)) + else: + for element in elements: + if commutes_with_gens(element): + centralizer_list.append(element) + return centralizer_list + elif hasattr(other, 'getitem'): + return _naive_list_centralizer(self, PermutationGroup(other), af) + elif hasattr(other, 'array_form'): + return _naive_list_centralizer(self, PermutationGroup([other]), af) + + +def _verify_bsgs(group, base, gens): + """ + Verify the correctness of a base and strong generating set. + + Explanation + =========== + + This is a naive implementation using the definition of a base and a strong + generating set relative to it. There are other procedures for + verifying a base and strong generating set, but this one will + serve for more robust testing. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import AlternatingGroup + >>> from sympy.combinatorics.testutil import _verify_bsgs + >>> A = AlternatingGroup(4) + >>> A.schreier_sims() + >>> _verify_bsgs(A, A.base, A.strong_gens) + True + + See Also + ======== + + sympy.combinatorics.perm_groups.PermutationGroup.schreier_sims + + """ + from sympy.combinatorics.perm_groups import PermutationGroup + strong_gens_distr = _distribute_gens_by_base(base, gens) + current_stabilizer = group + for i in range(len(base)): + candidate = PermutationGroup(strong_gens_distr[i]) + if current_stabilizer.order() != candidate.order(): + return False + current_stabilizer = current_stabilizer.stabilizer(base[i]) + if current_stabilizer.order() != 1: + return False + return True + + +def _verify_centralizer(group, arg, centr=None): + """ + Verify the centralizer of a group/set/element inside another group. + + This is used for testing ``.centralizer()`` from + ``sympy.combinatorics.perm_groups`` + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import (SymmetricGroup, + ... AlternatingGroup) + >>> from sympy.combinatorics.perm_groups import PermutationGroup + >>> from sympy.combinatorics.permutations import Permutation + >>> from sympy.combinatorics.testutil import _verify_centralizer + >>> S = SymmetricGroup(5) + >>> A = AlternatingGroup(5) + >>> centr = PermutationGroup([Permutation([0, 1, 2, 3, 4])]) + >>> _verify_centralizer(S, A, centr) + True + + See Also + ======== + + _naive_list_centralizer, + sympy.combinatorics.perm_groups.PermutationGroup.centralizer, + _cmp_perm_lists + + """ + if centr is None: + centr = group.centralizer(arg) + centr_list = list(centr.generate_dimino(af=True)) + centr_list_naive = _naive_list_centralizer(group, arg, af=True) + return _cmp_perm_lists(centr_list, centr_list_naive) + + +def _verify_normal_closure(group, arg, closure=None): + from sympy.combinatorics.perm_groups import PermutationGroup + """ + Verify the normal closure of a subgroup/subset/element in a group. + + This is used to test + sympy.combinatorics.perm_groups.PermutationGroup.normal_closure + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import (SymmetricGroup, + ... AlternatingGroup) + >>> from sympy.combinatorics.testutil import _verify_normal_closure + >>> S = SymmetricGroup(3) + >>> A = AlternatingGroup(3) + >>> _verify_normal_closure(S, A, closure=A) + True + + See Also + ======== + + sympy.combinatorics.perm_groups.PermutationGroup.normal_closure + + """ + if closure is None: + closure = group.normal_closure(arg) + conjugates = set() + if hasattr(arg, 'generators'): + subgr_gens = arg.generators + elif hasattr(arg, '__getitem__'): + subgr_gens = arg + elif hasattr(arg, 'array_form'): + subgr_gens = [arg] + for el in group.generate_dimino(): + conjugates.update(gen ^ el for gen in subgr_gens) + naive_closure = PermutationGroup(list(conjugates)) + return closure.is_subgroup(naive_closure) + + +def canonicalize_naive(g, dummies, sym, *v): + """ + Canonicalize tensor formed by tensors of the different types. + + Explanation + =========== + + sym_i symmetry under exchange of two component tensors of type `i` + None no symmetry + 0 commuting + 1 anticommuting + + Parameters + ========== + + g : Permutation representing the tensor. + dummies : List of dummy indices. + msym : Symmetry of the metric. + v : A list of (base_i, gens_i, n_i, sym_i) for tensors of type `i`. + base_i, gens_i BSGS for tensors of this type + n_i number of tensors of type `i` + + Returns + ======= + + Returns 0 if the tensor is zero, else returns the array form of + the permutation representing the canonical form of the tensor. + + Examples + ======== + + >>> from sympy.combinatorics.testutil import canonicalize_naive + >>> from sympy.combinatorics.tensor_can import get_symmetric_group_sgs + >>> from sympy.combinatorics import Permutation + >>> g = Permutation([1, 3, 2, 0, 4, 5]) + >>> base2, gens2 = get_symmetric_group_sgs(2) + >>> canonicalize_naive(g, [2, 3], 0, (base2, gens2, 2, 0)) + [0, 2, 1, 3, 4, 5] + """ + from sympy.combinatorics.perm_groups import PermutationGroup + from sympy.combinatorics.tensor_can import gens_products, dummy_sgs + from sympy.combinatorics.permutations import _af_rmul + v1 = [] + for i in range(len(v)): + base_i, gens_i, n_i, sym_i = v[i] + v1.append((base_i, gens_i, [[]]*n_i, sym_i)) + size, sbase, sgens = gens_products(*v1) + dgens = dummy_sgs(dummies, sym, size-2) + if isinstance(sym, int): + num_types = 1 + dummies = [dummies] + sym = [sym] + else: + num_types = len(sym) + dgens = [] + for i in range(num_types): + dgens.extend(dummy_sgs(dummies[i], sym[i], size - 2)) + S = PermutationGroup(sgens) + D = PermutationGroup([Permutation(x) for x in dgens]) + dlist = list(D.generate(af=True)) + g = g.array_form + st = set() + for s in S.generate(af=True): + h = _af_rmul(g, s) + for d in dlist: + q = tuple(_af_rmul(d, h)) + st.add(q) + a = list(st) + a.sort() + prev = (0,)*size + for h in a: + if h[:-2] == prev[:-2]: + if h[-1] != prev[-1]: + return 0 + prev = h + return list(a[0]) + + +def graph_certificate(gr): + """ + Return a certificate for the graph + + Parameters + ========== + + gr : adjacency list + + Explanation + =========== + + The graph is assumed to be unoriented and without + external lines. + + Associate to each vertex of the graph a symmetric tensor with + number of indices equal to the degree of the vertex; indices + are contracted when they correspond to the same line of the graph. + The canonical form of the tensor gives a certificate for the graph. + + This is not an efficient algorithm to get the certificate of a graph. + + Examples + ======== + + >>> from sympy.combinatorics.testutil import graph_certificate + >>> gr1 = {0:[1, 2, 3, 5], 1:[0, 2, 4], 2:[0, 1, 3, 4], 3:[0, 2, 4], 4:[1, 2, 3, 5], 5:[0, 4]} + >>> gr2 = {0:[1, 5], 1:[0, 2, 3, 4], 2:[1, 3, 5], 3:[1, 2, 4, 5], 4:[1, 3, 5], 5:[0, 2, 3, 4]} + >>> c1 = graph_certificate(gr1) + >>> c2 = graph_certificate(gr2) + >>> c1 + [0, 2, 4, 6, 1, 8, 10, 12, 3, 14, 16, 18, 5, 9, 15, 7, 11, 17, 13, 19, 20, 21] + >>> c1 == c2 + True + """ + from sympy.combinatorics.permutations import _af_invert + from sympy.combinatorics.tensor_can import get_symmetric_group_sgs, canonicalize + items = list(gr.items()) + items.sort(key=lambda x: len(x[1]), reverse=True) + pvert = [x[0] for x in items] + pvert = _af_invert(pvert) + + # the indices of the tensor are twice the number of lines of the graph + num_indices = 0 + for v, neigh in items: + num_indices += len(neigh) + # associate to each vertex its indices; for each line + # between two vertices assign the + # even index to the vertex which comes first in items, + # the odd index to the other vertex + vertices = [[] for i in items] + i = 0 + for v, neigh in items: + for v2 in neigh: + if pvert[v] < pvert[v2]: + vertices[pvert[v]].append(i) + vertices[pvert[v2]].append(i+1) + i += 2 + g = [] + for v in vertices: + g.extend(v) + assert len(g) == num_indices + g += [num_indices, num_indices + 1] + size = num_indices + 2 + assert sorted(g) == list(range(size)) + g = Permutation(g) + vlen = [0]*(len(vertices[0])+1) + for neigh in vertices: + vlen[len(neigh)] += 1 + v = [] + for i in range(len(vlen)): + n = vlen[i] + if n: + base, gens = get_symmetric_group_sgs(i) + v.append((base, gens, n, 0)) + v.reverse() + dummies = list(range(num_indices)) + can = canonicalize(g, dummies, 0, *v) + return can diff --git a/.venv/lib/python3.13/site-packages/sympy/combinatorics/util.py b/.venv/lib/python3.13/site-packages/sympy/combinatorics/util.py new file mode 100644 index 0000000000000000000000000000000000000000..fc73b02f94f4aae6f1b98bb3f0c837fd5a1d1e6d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/combinatorics/util.py @@ -0,0 +1,532 @@ +from sympy.combinatorics.permutations import Permutation, _af_invert, _af_rmul +from sympy.ntheory import isprime + +rmul = Permutation.rmul +_af_new = Permutation._af_new + +############################################ +# +# Utilities for computational group theory +# +############################################ + + +def _base_ordering(base, degree): + r""" + Order `\{0, 1, \dots, n-1\}` so that base points come first and in order. + + Parameters + ========== + + base : the base + degree : the degree of the associated permutation group + + Returns + ======= + + A list ``base_ordering`` such that ``base_ordering[point]`` is the + number of ``point`` in the ordering. + + Examples + ======== + + >>> from sympy.combinatorics import SymmetricGroup + >>> from sympy.combinatorics.util import _base_ordering + >>> S = SymmetricGroup(4) + >>> S.schreier_sims() + >>> _base_ordering(S.base, S.degree) + [0, 1, 2, 3] + + Notes + ===== + + This is used in backtrack searches, when we define a relation `\ll` on + the underlying set for a permutation group of degree `n`, + `\{0, 1, \dots, n-1\}`, so that if `(b_1, b_2, \dots, b_k)` is a base we + have `b_i \ll b_j` whenever `i>> from sympy.combinatorics.util import _check_cycles_alt_sym + >>> from sympy.combinatorics import Permutation + >>> a = Permutation([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [11, 12]]) + >>> _check_cycles_alt_sym(a) + False + >>> b = Permutation([[0, 1, 2, 3, 4, 5, 6], [7, 8, 9, 10]]) + >>> _check_cycles_alt_sym(b) + True + + See Also + ======== + + sympy.combinatorics.perm_groups.PermutationGroup.is_alt_sym + + """ + n = perm.size + af = perm.array_form + current_len = 0 + total_len = 0 + used = set() + for i in range(n//2): + if i not in used and i < n//2 - total_len: + current_len = 1 + used.add(i) + j = i + while af[j] != i: + current_len += 1 + j = af[j] + used.add(j) + total_len += current_len + if current_len > n//2 and current_len < n - 2 and isprime(current_len): + return True + return False + + +def _distribute_gens_by_base(base, gens): + r""" + Distribute the group elements ``gens`` by membership in basic stabilizers. + + Explanation + =========== + + Notice that for a base `(b_1, b_2, \dots, b_k)`, the basic stabilizers + are defined as `G^{(i)} = G_{b_1, \dots, b_{i-1}}` for + `i \in\{1, 2, \dots, k\}`. + + Parameters + ========== + + base : a sequence of points in `\{0, 1, \dots, n-1\}` + gens : a list of elements of a permutation group of degree `n`. + + Returns + ======= + list + List of length `k`, where `k` is the length of *base*. The `i`-th entry + contains those elements in *gens* which fix the first `i` elements of + *base* (so that the `0`-th entry is equal to *gens* itself). If no + element fixes the first `i` elements of *base*, the `i`-th element is + set to a list containing the identity element. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import DihedralGroup + >>> from sympy.combinatorics.util import _distribute_gens_by_base + >>> D = DihedralGroup(3) + >>> D.schreier_sims() + >>> D.strong_gens + [(0 1 2), (0 2), (1 2)] + >>> D.base + [0, 1] + >>> _distribute_gens_by_base(D.base, D.strong_gens) + [[(0 1 2), (0 2), (1 2)], + [(1 2)]] + + See Also + ======== + + _strong_gens_from_distr, _orbits_transversals_from_bsgs, + _handle_precomputed_bsgs + + """ + base_len = len(base) + degree = gens[0].size + stabs = [[] for _ in range(base_len)] + max_stab_index = 0 + for gen in gens: + j = 0 + while j < base_len - 1 and gen._array_form[base[j]] == base[j]: + j += 1 + if j > max_stab_index: + max_stab_index = j + for k in range(j + 1): + stabs[k].append(gen) + for i in range(max_stab_index + 1, base_len): + stabs[i].append(_af_new(list(range(degree)))) + return stabs + + +def _handle_precomputed_bsgs(base, strong_gens, transversals=None, + basic_orbits=None, strong_gens_distr=None): + """ + Calculate BSGS-related structures from those present. + + Explanation + =========== + + The base and strong generating set must be provided; if any of the + transversals, basic orbits or distributed strong generators are not + provided, they will be calculated from the base and strong generating set. + + Parameters + ========== + + base : the base + strong_gens : the strong generators + transversals : basic transversals + basic_orbits : basic orbits + strong_gens_distr : strong generators distributed by membership in basic stabilizers + + Returns + ======= + + (transversals, basic_orbits, strong_gens_distr) + where *transversals* are the basic transversals, *basic_orbits* are the + basic orbits, and *strong_gens_distr* are the strong generators distributed + by membership in basic stabilizers. + + Examples + ======== + + >>> from sympy.combinatorics.named_groups import DihedralGroup + >>> from sympy.combinatorics.util import _handle_precomputed_bsgs + >>> D = DihedralGroup(3) + >>> D.schreier_sims() + >>> _handle_precomputed_bsgs(D.base, D.strong_gens, + ... basic_orbits=D.basic_orbits) + ([{0: (2), 1: (0 1 2), 2: (0 2)}, {1: (2), 2: (1 2)}], [[0, 1, 2], [1, 2]], [[(0 1 2), (0 2), (1 2)], [(1 2)]]) + + See Also + ======== + + _orbits_transversals_from_bsgs, _distribute_gens_by_base + + """ + if strong_gens_distr is None: + strong_gens_distr = _distribute_gens_by_base(base, strong_gens) + if transversals is None: + if basic_orbits is None: + basic_orbits, transversals = \ + _orbits_transversals_from_bsgs(base, strong_gens_distr) + else: + transversals = \ + _orbits_transversals_from_bsgs(base, strong_gens_distr, + transversals_only=True) + else: + if basic_orbits is None: + base_len = len(base) + basic_orbits = [None]*base_len + for i in range(base_len): + basic_orbits[i] = list(transversals[i].keys()) + return transversals, basic_orbits, strong_gens_distr + + +def _orbits_transversals_from_bsgs(base, strong_gens_distr, + transversals_only=False, slp=False): + """ + Compute basic orbits and transversals from a base and strong generating set. + + Explanation + =========== + + The generators are provided as distributed across the basic stabilizers. + If the optional argument ``transversals_only`` is set to True, only the + transversals are returned. + + Parameters + ========== + + base : The base. + strong_gens_distr : Strong generators distributed by membership in basic stabilizers. + transversals_only : bool, default: False + A flag switching between returning only the + transversals and both orbits and transversals. + slp : bool, default: False + If ``True``, return a list of dictionaries containing the + generator presentations of the elements of the transversals, + i.e. the list of indices of generators from ``strong_gens_distr[i]`` + such that their product is the relevant transversal element. + + Examples + ======== + + >>> from sympy.combinatorics import SymmetricGroup + >>> from sympy.combinatorics.util import _distribute_gens_by_base + >>> S = SymmetricGroup(3) + >>> S.schreier_sims() + >>> strong_gens_distr = _distribute_gens_by_base(S.base, S.strong_gens) + >>> (S.base, strong_gens_distr) + ([0, 1], [[(0 1 2), (2)(0 1), (1 2)], [(1 2)]]) + + See Also + ======== + + _distribute_gens_by_base, _handle_precomputed_bsgs + + """ + from sympy.combinatorics.perm_groups import _orbit_transversal + base_len = len(base) + degree = strong_gens_distr[0][0].size + transversals = [None]*base_len + slps = [None]*base_len + if transversals_only is False: + basic_orbits = [None]*base_len + for i in range(base_len): + transversals[i], slps[i] = _orbit_transversal(degree, strong_gens_distr[i], + base[i], pairs=True, slp=True) + transversals[i] = dict(transversals[i]) + if transversals_only is False: + basic_orbits[i] = list(transversals[i].keys()) + if transversals_only: + return transversals + else: + if not slp: + return basic_orbits, transversals + return basic_orbits, transversals, slps + + +def _remove_gens(base, strong_gens, basic_orbits=None, strong_gens_distr=None): + """ + Remove redundant generators from a strong generating set. + + Parameters + ========== + + base : a base + strong_gens : a strong generating set relative to *base* + basic_orbits : basic orbits + strong_gens_distr : strong generators distributed by membership in basic stabilizers + + Returns + ======= + + A strong generating set with respect to ``base`` which is a subset of + ``strong_gens``. + + Examples + ======== + + >>> from sympy.combinatorics import SymmetricGroup + >>> from sympy.combinatorics.util import _remove_gens + >>> from sympy.combinatorics.testutil import _verify_bsgs + >>> S = SymmetricGroup(15) + >>> base, strong_gens = S.schreier_sims_incremental() + >>> new_gens = _remove_gens(base, strong_gens) + >>> len(new_gens) + 14 + >>> _verify_bsgs(S, base, new_gens) + True + + Notes + ===== + + This procedure is outlined in [1],p.95. + + References + ========== + + .. [1] Holt, D., Eick, B., O'Brien, E. + "Handbook of computational group theory" + + """ + from sympy.combinatorics.perm_groups import _orbit + base_len = len(base) + degree = strong_gens[0].size + if strong_gens_distr is None: + strong_gens_distr = _distribute_gens_by_base(base, strong_gens) + if basic_orbits is None: + basic_orbits = [] + for i in range(base_len): + basic_orbit = _orbit(degree, strong_gens_distr[i], base[i]) + basic_orbits.append(basic_orbit) + strong_gens_distr.append([]) + res = strong_gens[:] + for i in range(base_len - 1, -1, -1): + gens_copy = strong_gens_distr[i][:] + for gen in strong_gens_distr[i]: + if gen not in strong_gens_distr[i + 1]: + temp_gens = gens_copy[:] + temp_gens.remove(gen) + if temp_gens == []: + continue + temp_orbit = _orbit(degree, temp_gens, base[i]) + if temp_orbit == basic_orbits[i]: + gens_copy.remove(gen) + res.remove(gen) + return res + + +def _strip(g, base, orbits, transversals): + """ + Attempt to decompose a permutation using a (possibly partial) BSGS + structure. + + Explanation + =========== + + This is done by treating the sequence ``base`` as an actual base, and + the orbits ``orbits`` and transversals ``transversals`` as basic orbits and + transversals relative to it. + + This process is called "sifting". A sift is unsuccessful when a certain + orbit element is not found or when after the sift the decomposition + does not end with the identity element. + + The argument ``transversals`` is a list of dictionaries that provides + transversal elements for the orbits ``orbits``. + + Parameters + ========== + + g : permutation to be decomposed + base : sequence of points + orbits : list + A list in which the ``i``-th entry is an orbit of ``base[i]`` + under some subgroup of the pointwise stabilizer of ` + `base[0], base[1], ..., base[i - 1]``. The groups themselves are implicit + in this function since the only information we need is encoded in the orbits + and transversals + transversals : list + A list of orbit transversals associated with the orbits *orbits*. + + Examples + ======== + + >>> from sympy.combinatorics import Permutation, SymmetricGroup + >>> from sympy.combinatorics.util import _strip + >>> S = SymmetricGroup(5) + >>> S.schreier_sims() + >>> g = Permutation([0, 2, 3, 1, 4]) + >>> _strip(g, S.base, S.basic_orbits, S.basic_transversals) + ((4), 5) + + Notes + ===== + + The algorithm is described in [1],pp.89-90. The reason for returning + both the current state of the element being decomposed and the level + at which the sifting ends is that they provide important information for + the randomized version of the Schreier-Sims algorithm. + + References + ========== + + .. [1] Holt, D., Eick, B., O'Brien, E."Handbook of computational group theory" + + See Also + ======== + + sympy.combinatorics.perm_groups.PermutationGroup.schreier_sims + sympy.combinatorics.perm_groups.PermutationGroup.schreier_sims_random + + """ + h = g._array_form + base_len = len(base) + for i in range(base_len): + beta = h[base[i]] + if beta == base[i]: + continue + if beta not in orbits[i]: + return _af_new(h), i + 1 + u = transversals[i][beta]._array_form + h = _af_rmul(_af_invert(u), h) + return _af_new(h), base_len + 1 + + +def _strip_af(h, base, orbits, transversals, j, slp=[], slps={}): + """ + optimized _strip, with h, transversals and result in array form + if the stripped elements is the identity, it returns False, base_len + 1 + + j h[base[i]] == base[i] for i <= j + + """ + base_len = len(base) + for i in range(j+1, base_len): + beta = h[base[i]] + if beta == base[i]: + continue + if beta not in orbits[i]: + if not slp: + return h, i + 1 + return h, i + 1, slp + u = transversals[i][beta] + if h == u: + if not slp: + return False, base_len + 1 + return False, base_len + 1, slp + h = _af_rmul(_af_invert(u), h) + if slp: + u_slp = slps[i][beta][:] + u_slp.reverse() + u_slp = [(i, (g,)) for g in u_slp] + slp = u_slp + slp + if not slp: + return h, base_len + 1 + return h, base_len + 1, slp + + +def _strong_gens_from_distr(strong_gens_distr): + """ + Retrieve strong generating set from generators of basic stabilizers. + + This is just the union of the generators of the first and second basic + stabilizers. + + Parameters + ========== + + strong_gens_distr : strong generators distributed by membership in basic stabilizers + + Examples + ======== + + >>> from sympy.combinatorics import SymmetricGroup + >>> from sympy.combinatorics.util import (_strong_gens_from_distr, + ... _distribute_gens_by_base) + >>> S = SymmetricGroup(3) + >>> S.schreier_sims() + >>> S.strong_gens + [(0 1 2), (2)(0 1), (1 2)] + >>> strong_gens_distr = _distribute_gens_by_base(S.base, S.strong_gens) + >>> _strong_gens_from_distr(strong_gens_distr) + [(0 1 2), (2)(0 1), (1 2)] + + See Also + ======== + + _distribute_gens_by_base + + """ + if len(strong_gens_distr) == 1: + return strong_gens_distr[0][:] + else: + result = strong_gens_distr[0] + for gen in strong_gens_distr[1]: + if gen not in result: + result.append(gen) + return result diff --git a/.venv/lib/python3.13/site-packages/sympy/integrals/__init__.py b/.venv/lib/python3.13/site-packages/sympy/integrals/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e78fe96f84d68e9b119571eb22dedb7033811b23 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/integrals/__init__.py @@ -0,0 +1,45 @@ +"""Integration functions that integrate a SymPy expression. + + Examples + ======== + + >>> from sympy import integrate, sin + >>> from sympy.abc import x + >>> integrate(1/x,x) + log(x) + >>> integrate(sin(x),x) + -cos(x) +""" +from .integrals import integrate, Integral, line_integrate +from .transforms import (mellin_transform, inverse_mellin_transform, + MellinTransform, InverseMellinTransform, + laplace_transform, inverse_laplace_transform, + laplace_correspondence, laplace_initial_conds, + LaplaceTransform, InverseLaplaceTransform, + fourier_transform, inverse_fourier_transform, + FourierTransform, InverseFourierTransform, + sine_transform, inverse_sine_transform, + SineTransform, InverseSineTransform, + cosine_transform, inverse_cosine_transform, + CosineTransform, InverseCosineTransform, + hankel_transform, inverse_hankel_transform, + HankelTransform, InverseHankelTransform) +from .singularityfunctions import singularityintegrate + +__all__ = [ + 'integrate', 'Integral', 'line_integrate', + + 'mellin_transform', 'inverse_mellin_transform', 'MellinTransform', + 'InverseMellinTransform', 'laplace_transform', + 'inverse_laplace_transform', 'LaplaceTransform', + 'laplace_correspondence', 'laplace_initial_conds', + 'InverseLaplaceTransform', 'fourier_transform', + 'inverse_fourier_transform', 'FourierTransform', + 'InverseFourierTransform', 'sine_transform', 'inverse_sine_transform', + 'SineTransform', 'InverseSineTransform', 'cosine_transform', + 'inverse_cosine_transform', 'CosineTransform', 'InverseCosineTransform', + 'hankel_transform', 'inverse_hankel_transform', 'HankelTransform', + 'InverseHankelTransform', + + 'singularityintegrate', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/integrals/deltafunctions.py b/.venv/lib/python3.13/site-packages/sympy/integrals/deltafunctions.py new file mode 100644 index 0000000000000000000000000000000000000000..ae9fef0b0010a313e0866a54d978024dd475f882 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/integrals/deltafunctions.py @@ -0,0 +1,201 @@ +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.functions import DiracDelta, Heaviside +from .integrals import Integral, integrate + + +def change_mul(node, x): + """change_mul(node, x) + + Rearranges the operands of a product, bringing to front any simple + DiracDelta expression. + + Explanation + =========== + + If no simple DiracDelta expression was found, then all the DiracDelta + expressions are simplified (using DiracDelta.expand(diracdelta=True, wrt=x)). + + Return: (dirac, new node) + Where: + o dirac is either a simple DiracDelta expression or None (if no simple + expression was found); + o new node is either a simplified DiracDelta expressions or None (if it + could not be simplified). + + Examples + ======== + + >>> from sympy import DiracDelta, cos + >>> from sympy.integrals.deltafunctions import change_mul + >>> from sympy.abc import x, y + >>> change_mul(x*y*DiracDelta(x)*cos(x), x) + (DiracDelta(x), x*y*cos(x)) + >>> change_mul(x*y*DiracDelta(x**2 - 1)*cos(x), x) + (None, x*y*cos(x)*DiracDelta(x - 1)/2 + x*y*cos(x)*DiracDelta(x + 1)/2) + >>> change_mul(x*y*DiracDelta(cos(x))*cos(x), x) + (None, None) + + See Also + ======== + + sympy.functions.special.delta_functions.DiracDelta + deltaintegrate + """ + + new_args = [] + dirac = None + + #Sorting is needed so that we consistently collapse the same delta; + #However, we must preserve the ordering of non-commutative terms + c, nc = node.args_cnc() + sorted_args = sorted(c, key=default_sort_key) + sorted_args.extend(nc) + + for arg in sorted_args: + if arg.is_Pow and isinstance(arg.base, DiracDelta): + new_args.append(arg.func(arg.base, arg.exp - 1)) + arg = arg.base + if dirac is None and (isinstance(arg, DiracDelta) and arg.is_simple(x)): + dirac = arg + else: + new_args.append(arg) + if not dirac: # there was no simple dirac + new_args = [] + for arg in sorted_args: + if isinstance(arg, DiracDelta): + new_args.append(arg.expand(diracdelta=True, wrt=x)) + elif arg.is_Pow and isinstance(arg.base, DiracDelta): + new_args.append(arg.func(arg.base.expand(diracdelta=True, wrt=x), arg.exp)) + else: + new_args.append(arg) + if new_args != sorted_args: + nnode = Mul(*new_args).expand() + else: # if the node didn't change there is nothing to do + nnode = None + return (None, nnode) + return (dirac, Mul(*new_args)) + + +def deltaintegrate(f, x): + """ + deltaintegrate(f, x) + + Explanation + =========== + + The idea for integration is the following: + + - If we are dealing with a DiracDelta expression, i.e. DiracDelta(g(x)), + we try to simplify it. + + If we could simplify it, then we integrate the resulting expression. + We already know we can integrate a simplified expression, because only + simple DiracDelta expressions are involved. + + If we couldn't simplify it, there are two cases: + + 1) The expression is a simple expression: we return the integral, + taking care if we are dealing with a Derivative or with a proper + DiracDelta. + + 2) The expression is not simple (i.e. DiracDelta(cos(x))): we can do + nothing at all. + + - If the node is a multiplication node having a DiracDelta term: + + First we expand it. + + If the expansion did work, then we try to integrate the expansion. + + If not, we try to extract a simple DiracDelta term, then we have two + cases: + + 1) We have a simple DiracDelta term, so we return the integral. + + 2) We didn't have a simple term, but we do have an expression with + simplified DiracDelta terms, so we integrate this expression. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.integrals.deltafunctions import deltaintegrate + >>> from sympy import sin, cos, DiracDelta + >>> deltaintegrate(x*sin(x)*cos(x)*DiracDelta(x - 1), x) + sin(1)*cos(1)*Heaviside(x - 1) + >>> deltaintegrate(y**2*DiracDelta(x - z)*DiracDelta(y - z), y) + z**2*DiracDelta(x - z)*Heaviside(y - z) + + See Also + ======== + + sympy.functions.special.delta_functions.DiracDelta + sympy.integrals.integrals.Integral + """ + if not f.has(DiracDelta): + return None + + # g(x) = DiracDelta(h(x)) + if f.func == DiracDelta: + h = f.expand(diracdelta=True, wrt=x) + if h == f: # can't simplify the expression + #FIXME: the second term tells whether is DeltaDirac or Derivative + #For integrating derivatives of DiracDelta we need the chain rule + if f.is_simple(x): + if (len(f.args) <= 1 or f.args[1] == 0): + return Heaviside(f.args[0]) + else: + return (DiracDelta(f.args[0], f.args[1] - 1) / + f.args[0].as_poly().LC()) + else: # let's try to integrate the simplified expression + fh = integrate(h, x) + return fh + elif f.is_Mul or f.is_Pow: # g(x) = a*b*c*f(DiracDelta(h(x)))*d*e + g = f.expand() + if f != g: # the expansion worked + fh = integrate(g, x) + if fh is not None and not isinstance(fh, Integral): + return fh + else: + # no expansion performed, try to extract a simple DiracDelta term + deltaterm, rest_mult = change_mul(f, x) + + if not deltaterm: + if rest_mult: + fh = integrate(rest_mult, x) + return fh + else: + from sympy.solvers import solve + deltaterm = deltaterm.expand(diracdelta=True, wrt=x) + if deltaterm.is_Mul: # Take out any extracted factors + deltaterm, rest_mult_2 = change_mul(deltaterm, x) + rest_mult = rest_mult*rest_mult_2 + point = solve(deltaterm.args[0], x)[0] + + # Return the largest hyperreal term left after + # repeated integration by parts. For example, + # + # integrate(y*DiracDelta(x, 1),x) == y*DiracDelta(x,0), not 0 + # + # This is so Integral(y*DiracDelta(x).diff(x),x).doit() + # will return y*DiracDelta(x) instead of 0 or DiracDelta(x), + # both of which are correct everywhere the value is defined + # but give wrong answers for nested integration. + n = (0 if len(deltaterm.args)==1 else deltaterm.args[1]) + m = 0 + while n >= 0: + r = S.NegativeOne**n*rest_mult.diff(x, n).subs(x, point) + if r.is_zero: + n -= 1 + m += 1 + else: + if m == 0: + return r*Heaviside(x - point) + else: + return r*DiracDelta(x,m-1) + # In some very weak sense, x=0 is still a singularity, + # but we hope will not be of any practical consequence. + return S.Zero + return None diff --git a/.venv/lib/python3.13/site-packages/sympy/integrals/heurisch.py b/.venv/lib/python3.13/site-packages/sympy/integrals/heurisch.py new file mode 100644 index 0000000000000000000000000000000000000000..a27e2700afd08db16c7a86020eabe5feeb6e1c85 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/integrals/heurisch.py @@ -0,0 +1,781 @@ +from __future__ import annotations + +from collections import defaultdict +from functools import reduce +from itertools import permutations + +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.mul import Mul +from sympy.core.symbol import Wild, Dummy, Symbol +from sympy.core.basic import sympify +from sympy.core.numbers import Rational, pi, I +from sympy.core.relational import Eq, Ne +from sympy.core.singleton import S +from sympy.core.sorting import ordered +from sympy.core.traversal import iterfreeargs + +from sympy.functions import exp, sin, cos, tan, cot, asin, atan +from sympy.functions import log, sinh, cosh, tanh, coth, asinh +from sympy.functions import sqrt, erf, erfi, li, Ei +from sympy.functions import besselj, bessely, besseli, besselk +from sympy.functions import hankel1, hankel2, jn, yn +from sympy.functions.elementary.complexes import Abs, re, im, sign, arg +from sympy.functions.elementary.exponential import LambertW +from sympy.functions.elementary.integers import floor, ceiling +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.delta_functions import Heaviside, DiracDelta + +from sympy.simplify.radsimp import collect + +from sympy.logic.boolalg import And, Or +from sympy.utilities.iterables import uniq + +from sympy.polys import quo, gcd, lcm, factor_list, cancel, PolynomialError +from sympy.polys.monomials import itermonomials +from sympy.polys.polyroots import root_factors + +from sympy.polys.rings import PolyRing +from sympy.polys.solvers import solve_lin_sys +from sympy.polys.constructor import construct_domain + +from sympy.integrals.integrals import integrate + + +def components(f, x): + """ + Returns a set of all functional components of the given expression + which includes symbols, function applications and compositions and + non-integer powers. Fractional powers are collected with + minimal, positive exponents. + + Examples + ======== + + >>> from sympy import cos, sin + >>> from sympy.abc import x + >>> from sympy.integrals.heurisch import components + + >>> components(sin(x)*cos(x)**2, x) + {x, sin(x), cos(x)} + + See Also + ======== + + heurisch + """ + result = set() + + if f.has_free(x): + if f.is_symbol and f.is_commutative: + result.add(f) + elif f.is_Function or f.is_Derivative: + for g in f.args: + result |= components(g, x) + + result.add(f) + elif f.is_Pow: + result |= components(f.base, x) + + if not f.exp.is_Integer: + if f.exp.is_Rational: + result.add(f.base**Rational(1, f.exp.q)) + else: + result |= components(f.exp, x) | {f} + else: + for g in f.args: + result |= components(g, x) + + return result + +# name -> [] of symbols +_symbols_cache: dict[str, list[Dummy]] = {} + + +# NB @cacheit is not convenient here +def _symbols(name, n): + """get vector of symbols local to this module""" + try: + lsyms = _symbols_cache[name] + except KeyError: + lsyms = [] + _symbols_cache[name] = lsyms + + while len(lsyms) < n: + lsyms.append( Dummy('%s%i' % (name, len(lsyms))) ) + + return lsyms[:n] + + +def heurisch_wrapper(f, x, rewrite=False, hints=None, mappings=None, retries=3, + degree_offset=0, unnecessary_permutations=None, + _try_heurisch=None): + """ + A wrapper around the heurisch integration algorithm. + + Explanation + =========== + + This method takes the result from heurisch and checks for poles in the + denominator. For each of these poles, the integral is reevaluated, and + the final integration result is given in terms of a Piecewise. + + Examples + ======== + + >>> from sympy import cos, symbols + >>> from sympy.integrals.heurisch import heurisch, heurisch_wrapper + >>> n, x = symbols('n x') + >>> heurisch(cos(n*x), x) + sin(n*x)/n + >>> heurisch_wrapper(cos(n*x), x) + Piecewise((sin(n*x)/n, Ne(n, 0)), (x, True)) + + See Also + ======== + + heurisch + """ + from sympy.solvers.solvers import solve, denoms + f = sympify(f) + if not f.has_free(x): + return f*x + + res = heurisch(f, x, rewrite, hints, mappings, retries, degree_offset, + unnecessary_permutations, _try_heurisch) + if not isinstance(res, Basic): + return res + + # We consider each denominator in the expression, and try to find + # cases where one or more symbolic denominator might be zero. The + # conditions for these cases are stored in the list slns. + # + # Since denoms returns a set we use ordered. This is important because the + # ordering of slns determines the order of the resulting Piecewise so we + # need a deterministic order here to make the output deterministic. + slns = [] + for d in ordered(denoms(res)): + try: + slns += solve([d], dict=True, exclude=(x,)) + except NotImplementedError: + pass + if not slns: + return res + slns = list(uniq(slns)) + # Remove the solutions corresponding to poles in the original expression. + slns0 = [] + for d in denoms(f): + try: + slns0 += solve([d], dict=True, exclude=(x,)) + except NotImplementedError: + pass + slns = [s for s in slns if s not in slns0] + if not slns: + return res + if len(slns) > 1: + eqs = [] + for sub_dict in slns: + eqs.extend([Eq(key, value) for key, value in sub_dict.items()]) + slns = solve(eqs, dict=True, exclude=(x,)) + slns + # For each case listed in the list slns, we reevaluate the integral. + pairs = [] + for sub_dict in slns: + expr = heurisch(f.subs(sub_dict), x, rewrite, hints, mappings, retries, + degree_offset, unnecessary_permutations, + _try_heurisch) + cond = And(*[Eq(key, value) for key, value in sub_dict.items()]) + generic = Or(*[Ne(key, value) for key, value in sub_dict.items()]) + if expr is None: + expr = integrate(f.subs(sub_dict),x) + pairs.append((expr, cond)) + # If there is one condition, put the generic case first. Otherwise, + # doing so may lead to longer Piecewise formulas + if len(pairs) == 1: + pairs = [(heurisch(f, x, rewrite, hints, mappings, retries, + degree_offset, unnecessary_permutations, + _try_heurisch), + generic), + (pairs[0][0], True)] + else: + pairs.append((heurisch(f, x, rewrite, hints, mappings, retries, + degree_offset, unnecessary_permutations, + _try_heurisch), + True)) + return Piecewise(*pairs) + +class BesselTable: + """ + Derivatives of Bessel functions of orders n and n-1 + in terms of each other. + + See the docstring of DiffCache. + """ + + def __init__(self): + self.table = {} + self.n = Dummy('n') + self.z = Dummy('z') + self._create_table() + + def _create_table(t): + table, n, z = t.table, t.n, t.z + for f in (besselj, bessely, hankel1, hankel2): + table[f] = (f(n-1, z) - n*f(n, z)/z, + (n-1)*f(n-1, z)/z - f(n, z)) + + f = besseli + table[f] = (f(n-1, z) - n*f(n, z)/z, + (n-1)*f(n-1, z)/z + f(n, z)) + f = besselk + table[f] = (-f(n-1, z) - n*f(n, z)/z, + (n-1)*f(n-1, z)/z - f(n, z)) + + for f in (jn, yn): + table[f] = (f(n-1, z) - (n+1)*f(n, z)/z, + (n-1)*f(n-1, z)/z - f(n, z)) + + def diffs(t, f, n, z): + if f in t.table: + diff0, diff1 = t.table[f] + repl = [(t.n, n), (t.z, z)] + return (diff0.subs(repl), diff1.subs(repl)) + + def has(t, f): + return f in t.table + +_bessel_table = None + +class DiffCache: + """ + Store for derivatives of expressions. + + Explanation + =========== + + The standard form of the derivative of a Bessel function of order n + contains two Bessel functions of orders n-1 and n+1, respectively. + Such forms cannot be used in parallel Risch algorithm, because + there is a linear recurrence relation between the three functions + while the algorithm expects that functions and derivatives are + represented in terms of algebraically independent transcendentals. + + The solution is to take two of the functions, e.g., those of orders + n and n-1, and to express the derivatives in terms of the pair. + To guarantee that the proper form is used the two derivatives are + cached as soon as one is encountered. + + Derivatives of other functions are also cached at no extra cost. + All derivatives are with respect to the same variable `x`. + """ + + def __init__(self, x): + self.cache = {} + self.x = x + + global _bessel_table + if not _bessel_table: + _bessel_table = BesselTable() + + def get_diff(self, f): + cache = self.cache + + if f in cache: + pass + elif (not hasattr(f, 'func') or + not _bessel_table.has(f.func)): + cache[f] = cancel(f.diff(self.x)) + else: + n, z = f.args + d0, d1 = _bessel_table.diffs(f.func, n, z) + dz = self.get_diff(z) + cache[f] = d0*dz + cache[f.func(n-1, z)] = d1*dz + + return cache[f] + +def heurisch(f, x, rewrite=False, hints=None, mappings=None, retries=3, + degree_offset=0, unnecessary_permutations=None, + _try_heurisch=None): + """ + Compute indefinite integral using heuristic Risch algorithm. + + Explanation + =========== + + This is a heuristic approach to indefinite integration in finite + terms using the extended heuristic (parallel) Risch algorithm, based + on Manuel Bronstein's "Poor Man's Integrator". + + The algorithm supports various classes of functions including + transcendental elementary or special functions like Airy, + Bessel, Whittaker and Lambert. + + Note that this algorithm is not a decision procedure. If it isn't + able to compute the antiderivative for a given function, then this is + not a proof that such a functions does not exist. One should use + recursive Risch algorithm in such case. It's an open question if + this algorithm can be made a full decision procedure. + + This is an internal integrator procedure. You should use top level + 'integrate' function in most cases, as this procedure needs some + preprocessing steps and otherwise may fail. + + Specification + ============= + + heurisch(f, x, rewrite=False, hints=None) + + where + f : expression + x : symbol + + rewrite -> force rewrite 'f' in terms of 'tan' and 'tanh' + hints -> a list of functions that may appear in anti-derivate + + - hints = None --> no suggestions at all + - hints = [ ] --> try to figure out + - hints = [f1, ..., fn] --> we know better + + Examples + ======== + + >>> from sympy import tan + >>> from sympy.integrals.heurisch import heurisch + >>> from sympy.abc import x, y + + >>> heurisch(y*tan(x), x) + y*log(tan(x)**2 + 1)/2 + + See Manuel Bronstein's "Poor Man's Integrator": + + References + ========== + + .. [1] https://www-sop.inria.fr/cafe/Manuel.Bronstein/pmint/index.html + + For more information on the implemented algorithm refer to: + + .. [2] K. Geddes, L. Stefanus, On the Risch-Norman Integration + Method and its Implementation in Maple, Proceedings of + ISSAC'89, ACM Press, 212-217. + + .. [3] J. H. Davenport, On the Parallel Risch Algorithm (I), + Proceedings of EUROCAM'82, LNCS 144, Springer, 144-157. + + .. [4] J. H. Davenport, On the Parallel Risch Algorithm (III): + Use of Tangents, SIGSAM Bulletin 16 (1982), 3-6. + + .. [5] J. H. Davenport, B. M. Trager, On the Parallel Risch + Algorithm (II), ACM Transactions on Mathematical + Software 11 (1985), 356-362. + + See Also + ======== + + sympy.integrals.integrals.Integral.doit + sympy.integrals.integrals.Integral + sympy.integrals.heurisch.components + """ + f = sympify(f) + + # There are some functions that Heurisch cannot currently handle, + # so do not even try. + # Set _try_heurisch=True to skip this check + if _try_heurisch is not True: + if f.has(Abs, re, im, sign, Heaviside, DiracDelta, floor, ceiling, arg): + return + + if not f.has_free(x): + return f*x + + if not f.is_Add: + indep, f = f.as_independent(x) + else: + indep = S.One + + rewritables = { + (sin, cos, cot): tan, + (sinh, cosh, coth): tanh, + } + + if rewrite: + for candidates, rule in rewritables.items(): + f = f.rewrite(candidates, rule) + else: + for candidates in rewritables.keys(): + if f.has(*candidates): + break + else: + rewrite = True + + terms = components(f, x) + dcache = DiffCache(x) + + if hints is not None: + if not hints: + a = Wild('a', exclude=[x]) + b = Wild('b', exclude=[x]) + c = Wild('c', exclude=[x]) + + for g in set(terms): # using copy of terms + if g.is_Function: + if isinstance(g, li): + M = g.args[0].match(a*x**b) + + if M is not None: + terms.add( x*(li(M[a]*x**M[b]) - (M[a]*x**M[b])**(-1/M[b])*Ei((M[b]+1)*log(M[a]*x**M[b])/M[b])) ) + #terms.add( x*(li(M[a]*x**M[b]) - (x**M[b])**(-1/M[b])*Ei((M[b]+1)*log(M[a]*x**M[b])/M[b])) ) + #terms.add( x*(li(M[a]*x**M[b]) - x*Ei((M[b]+1)*log(M[a]*x**M[b])/M[b])) ) + #terms.add( li(M[a]*x**M[b]) - Ei((M[b]+1)*log(M[a]*x**M[b])/M[b]) ) + + elif isinstance(g, exp): + M = g.args[0].match(a*x**2) + + if M is not None: + if M[a].is_positive: + terms.add(erfi(sqrt(M[a])*x)) + else: # M[a].is_negative or unknown + terms.add(erf(sqrt(-M[a])*x)) + + M = g.args[0].match(a*x**2 + b*x + c) + + if M is not None: + if M[a].is_positive: + terms.add(sqrt(pi/4*(-M[a]))*exp(M[c] - M[b]**2/(4*M[a]))* + erfi(sqrt(M[a])*x + M[b]/(2*sqrt(M[a])))) + elif M[a].is_negative: + terms.add(sqrt(pi/4*(-M[a]))*exp(M[c] - M[b]**2/(4*M[a]))* + erf(sqrt(-M[a])*x - M[b]/(2*sqrt(-M[a])))) + + M = g.args[0].match(a*log(x)**2) + + if M is not None: + if M[a].is_positive: + terms.add(erfi(sqrt(M[a])*log(x) + 1/(2*sqrt(M[a])))) + if M[a].is_negative: + terms.add(erf(sqrt(-M[a])*log(x) - 1/(2*sqrt(-M[a])))) + + elif g.is_Pow: + if g.exp.is_Rational and g.exp.q == 2: + M = g.base.match(a*x**2 + b) + + if M is not None and M[b].is_positive: + if M[a].is_positive: + terms.add(asinh(sqrt(M[a]/M[b])*x)) + elif M[a].is_negative: + terms.add(asin(sqrt(-M[a]/M[b])*x)) + + M = g.base.match(a*x**2 - b) + + if M is not None and M[b].is_positive: + if M[a].is_positive: + dF = 1/sqrt(M[a]*x**2 - M[b]) + F = log(2*sqrt(M[a])*sqrt(M[a]*x**2 - M[b]) + 2*M[a]*x)/sqrt(M[a]) + dcache.cache[F] = dF # hack: F.diff(x) doesn't automatically simplify to f + terms.add(F) + elif M[a].is_negative: + terms.add(-M[b]/2*sqrt(-M[a])* + atan(sqrt(-M[a])*x/sqrt(M[a]*x**2 - M[b]))) + + else: + terms |= set(hints) + + for g in set(terms): # using copy of terms + terms |= components(dcache.get_diff(g), x) + + # XXX: The commented line below makes heurisch more deterministic wrt + # PYTHONHASHSEED and the iteration order of sets. There are other places + # where sets are iterated over but this one is possibly the most important. + # Theoretically the order here should not matter but different orderings + # can expose potential bugs in the different code paths so potentially it + # is better to keep the non-determinism. + # + # terms = list(ordered(terms)) + + # TODO: caching is significant factor for why permutations work at all. Change this. + V = _symbols('x', len(terms)) + + + # sort mapping expressions from largest to smallest (last is always x). + mapping = list(reversed(list(zip(*ordered( # + [(a[0].as_independent(x)[1], a) for a in zip(terms, V)])))[1])) # + rev_mapping = {v: k for k, v in mapping} # + if mappings is None: # + # optimizing the number of permutations of mapping # + assert mapping[-1][0] == x # if not, find it and correct this comment + unnecessary_permutations = [mapping.pop(-1)] + # permute types of objects + types = defaultdict(list) + for i in mapping: + e, _ = i + types[type(e)].append(i) + mapping = [types[i] for i in types] + def _iter_mappings(): + for i in permutations(mapping): + # make the expression of a given type be ordered + yield [j for i in i for j in ordered(i)] + mappings = _iter_mappings() + else: + unnecessary_permutations = unnecessary_permutations or [] + + def _substitute(expr): + return expr.subs(mapping) + + for mapping in mappings: + mapping = list(mapping) + mapping = mapping + unnecessary_permutations + diffs = [ _substitute(dcache.get_diff(g)) for g in terms ] + denoms = [ g.as_numer_denom()[1] for g in diffs ] + if all(h.is_polynomial(*V) for h in denoms) and _substitute(f).is_rational_function(*V): + denom = reduce(lambda p, q: lcm(p, q, *V), denoms) + break + else: + if not rewrite: + result = heurisch(f, x, rewrite=True, hints=hints, + unnecessary_permutations=unnecessary_permutations) + + if result is not None: + return indep*result + return None + + numers = [ cancel(denom*g) for g in diffs ] + def _derivation(h): + return Add(*[ d * h.diff(v) for d, v in zip(numers, V) ]) + + def _deflation(p): + for y in V: + if not p.has(y): + continue + + if _derivation(p) is not S.Zero: + c, q = p.as_poly(y).primitive() + return _deflation(c)*gcd(q, q.diff(y)).as_expr() + + return p + + def _splitter(p): + for y in V: + if not p.has(y): + continue + + if _derivation(y) is not S.Zero: + c, q = p.as_poly(y).primitive() + + q = q.as_expr() + + h = gcd(q, _derivation(q), y) + s = quo(h, gcd(q, q.diff(y), y), y) + + c_split = _splitter(c) + + if s.as_poly(y).degree() == 0: + return (c_split[0], q * c_split[1]) + + q_split = _splitter(cancel(q / s)) + + return (c_split[0]*q_split[0]*s, c_split[1]*q_split[1]) + + return (S.One, p) + + special = {} + + for term in terms: + if term.is_Function: + if isinstance(term, tan): + special[1 + _substitute(term)**2] = False + elif isinstance(term, tanh): + special[1 + _substitute(term)] = False + special[1 - _substitute(term)] = False + elif isinstance(term, LambertW): + special[_substitute(term)] = True + + F = _substitute(f) + + P, Q = F.as_numer_denom() + + u_split = _splitter(denom) + v_split = _splitter(Q) + + polys = set(list(v_split) + [ u_split[0] ] + list(special.keys())) + + s = u_split[0] * Mul(*[ k for k, v in special.items() if v ]) + polified = [ p.as_poly(*V) for p in [s, P, Q] ] + + if None in polified: + return None + + #--- definitions for _integrate + a, b, c = [ p.total_degree() for p in polified ] + + poly_denom = (s * v_split[0] * _deflation(v_split[1])).as_expr() + + def _exponent(g): + if g.is_Pow: + if g.exp.is_Rational and g.exp.q != 1: + if g.exp.p > 0: + return g.exp.p + g.exp.q - 1 + else: + return abs(g.exp.p + g.exp.q) + else: + return 1 + elif not g.is_Atom and g.args: + return max(_exponent(h) for h in g.args) + else: + return 1 + + A, B = _exponent(f), a + max(b, c) + + if A > 1 and B > 1: + monoms = tuple(ordered(itermonomials(V, A + B - 1 + degree_offset))) + else: + monoms = tuple(ordered(itermonomials(V, A + B + degree_offset))) + + poly_coeffs = _symbols('A', len(monoms)) + + poly_part = Add(*[ poly_coeffs[i]*monomial + for i, monomial in enumerate(monoms) ]) + + reducibles = set() + + for poly in ordered(polys): + coeff, factors = factor_list(poly, *V) + reducibles.add(coeff) + reducibles.update(fact for fact, mul in factors) + + def _integrate(field=None): + atans = set() + pairs = set() + + if field == 'Q': + irreducibles = set(reducibles) + else: + setV = set(V) + irreducibles = set() + for poly in ordered(reducibles): + zV = setV & set(iterfreeargs(poly)) + for z in ordered(zV): + s = set(root_factors(poly, z, filter=field)) + irreducibles |= s + break + + log_part, atan_part = [], [] + + for poly in ordered(irreducibles): + m = collect(poly, I, evaluate=False) + y = m.get(I, S.Zero) + if y: + x = m.get(S.One, S.Zero) + if x.has(I) or y.has(I): + continue # nontrivial x + I*y + pairs.add((x, y)) + irreducibles.remove(poly) + + while pairs: + x, y = pairs.pop() + if (x, -y) in pairs: + pairs.remove((x, -y)) + # Choosing b with no minus sign + if y.could_extract_minus_sign(): + y = -y + irreducibles.add(x*x + y*y) + atans.add(atan(x/y)) + else: + irreducibles.add(x + I*y) + + + B = _symbols('B', len(irreducibles)) + C = _symbols('C', len(atans)) + + # Note: the ordering matters here + for poly, b in reversed(list(zip(ordered(irreducibles), B))): + if poly.has(*V): + poly_coeffs.append(b) + log_part.append(b * log(poly)) + + for poly, c in reversed(list(zip(ordered(atans), C))): + if poly.has(*V): + poly_coeffs.append(c) + atan_part.append(c * poly) + + # TODO: Currently it's better to use symbolic expressions here instead + # of rational functions, because it's simpler and FracElement doesn't + # give big speed improvement yet. This is because cancellation is slow + # due to slow polynomial GCD algorithms. If this gets improved then + # revise this code. + candidate = poly_part/poly_denom + Add(*log_part) + Add(*atan_part) + h = F - _derivation(candidate) / denom + raw_numer = h.as_numer_denom()[0] + + # Rewrite raw_numer as a polynomial in K[coeffs][V] where K is a field + # that we have to determine. We can't use simply atoms() because log(3), + # sqrt(y) and similar expressions can appear, leading to non-trivial + # domains. + syms = set(poly_coeffs) | set(V) + non_syms = set() + + def find_non_syms(expr): + if expr.is_Integer or expr.is_Rational: + pass # ignore trivial numbers + elif expr in syms: + pass # ignore variables + elif not expr.has_free(*syms): + non_syms.add(expr) + elif expr.is_Add or expr.is_Mul or expr.is_Pow: + list(map(find_non_syms, expr.args)) + else: + # TODO: Non-polynomial expression. This should have been + # filtered out at an earlier stage. + raise PolynomialError + + try: + find_non_syms(raw_numer) + except PolynomialError: + return None + else: + ground, _ = construct_domain(non_syms, field=True) + + coeff_ring = PolyRing(poly_coeffs, ground) + ring = PolyRing(V, coeff_ring) + try: + numer = ring.from_expr(raw_numer) + except ValueError: + raise PolynomialError + solution = solve_lin_sys(numer.coeffs(), coeff_ring, _raw=False) + + if solution is None: + return None + else: + return candidate.xreplace(solution).xreplace( + dict(zip(poly_coeffs, [S.Zero]*len(poly_coeffs)))) + + if all(isinstance(_, Symbol) for _ in V): + more_free = F.free_symbols - set(V) + else: + Fd = F.as_dummy() + more_free = Fd.xreplace(dict(zip(V, (Dummy() for _ in V))) + ).free_symbols & Fd.free_symbols + if not more_free: + # all free generators are identified in V + solution = _integrate('Q') + + if solution is None: + solution = _integrate() + else: + solution = _integrate() + + if solution is not None: + antideriv = solution.subs(rev_mapping) + antideriv = cancel(antideriv).expand() + + if antideriv.is_Add: + antideriv = antideriv.as_independent(x)[1] + + return indep*antideriv + else: + if retries >= 0: + result = heurisch(f, x, mappings=mappings, rewrite=rewrite, hints=hints, retries=retries - 1, unnecessary_permutations=unnecessary_permutations) + + if result is not None: + return indep*result + + return None diff --git a/.venv/lib/python3.13/site-packages/sympy/integrals/integrals.py b/.venv/lib/python3.13/site-packages/sympy/integrals/integrals.py new file mode 100644 index 0000000000000000000000000000000000000000..b9ed4a22802acc455f5162e109fc575223c97338 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/integrals/integrals.py @@ -0,0 +1,1640 @@ +from __future__ import annotations + +from sympy.concrete.expr_with_limits import AddWithLimits +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.exprtools import factor_terms +from sympy.core.function import diff +from sympy.core.logic import fuzzy_bool +from sympy.core.mul import Mul +from sympy.core.numbers import oo, pi +from sympy.core.relational import Ne +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, Wild) +from sympy.core.sympify import sympify +from sympy.functions import Piecewise, sqrt, piecewise_fold, tan, cot, atan +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.complexes import Abs, sign +from sympy.functions.elementary.miscellaneous import Min, Max +from sympy.functions.special.singularity_functions import Heaviside +from .rationaltools import ratint +from sympy.matrices import MatrixBase +from sympy.polys import Poly, PolynomialError +from sympy.series.formal import FormalPowerSeries +from sympy.series.limits import limit +from sympy.series.order import Order +from sympy.tensor.functions import shape +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import is_sequence +from sympy.utilities.misc import filldedent + + +class Integral(AddWithLimits): + """Represents unevaluated integral.""" + + __slots__ = () + + args: tuple[Expr, Tuple] # type: ignore + + def __new__(cls, function, *symbols, **assumptions) -> Integral: + """Create an unevaluated integral. + + Explanation + =========== + + Arguments are an integrand followed by one or more limits. + + If no limits are given and there is only one free symbol in the + expression, that symbol will be used, otherwise an error will be + raised. + + >>> from sympy import Integral + >>> from sympy.abc import x, y + >>> Integral(x) + Integral(x, x) + >>> Integral(y) + Integral(y, y) + + When limits are provided, they are interpreted as follows (using + ``x`` as though it were the variable of integration): + + (x,) or x - indefinite integral + (x, a) - "evaluate at" integral is an abstract antiderivative + (x, a, b) - definite integral + + The ``as_dummy`` method can be used to see which symbols cannot be + targeted by subs: those with a prepended underscore cannot be + changed with ``subs``. (Also, the integration variables themselves -- + the first element of a limit -- can never be changed by subs.) + + >>> i = Integral(x, x) + >>> at = Integral(x, (x, x)) + >>> i.as_dummy() + Integral(x, x) + >>> at.as_dummy() + Integral(_0, (_0, x)) + + """ + + #This will help other classes define their own definitions + #of behaviour with Integral. + if hasattr(function, '_eval_Integral'): + return function._eval_Integral(*symbols, **assumptions) + + if isinstance(function, Poly): + sympy_deprecation_warning( + """ + integrate(Poly) and Integral(Poly) are deprecated. Instead, + use the Poly.integrate() method, or convert the Poly to an + Expr first with the Poly.as_expr() method. + """, + deprecated_since_version="1.6", + active_deprecations_target="deprecated-integrate-poly") + + obj = AddWithLimits.__new__(cls, function, *symbols, **assumptions) + return obj + + def __getnewargs__(self): + return (self.function,) + tuple([tuple(xab) for xab in self.limits]) + + @property + def free_symbols(self): + """ + This method returns the symbols that will exist when the + integral is evaluated. This is useful if one is trying to + determine whether an integral depends on a certain + symbol or not. + + Examples + ======== + + >>> from sympy import Integral + >>> from sympy.abc import x, y + >>> Integral(x, (x, y, 1)).free_symbols + {y} + + See Also + ======== + + sympy.concrete.expr_with_limits.ExprWithLimits.function + sympy.concrete.expr_with_limits.ExprWithLimits.limits + sympy.concrete.expr_with_limits.ExprWithLimits.variables + """ + return super().free_symbols + + def _eval_is_zero(self): + # This is a very naive and quick test, not intended to do the integral to + # answer whether it is zero or not, e.g. Integral(sin(x), (x, 0, 2*pi)) + # is zero but this routine should return None for that case. But, like + # Mul, there are trivial situations for which the integral will be + # zero so we check for those. + if self.function.is_zero: + return True + got_none = False + for l in self.limits: + if len(l) == 3: + z = (l[1] == l[2]) or (l[1] - l[2]).is_zero + if z: + return True + elif z is None: + got_none = True + free = self.function.free_symbols + for xab in self.limits: + if len(xab) == 1: + free.add(xab[0]) + continue + if len(xab) == 2 and xab[0] not in free: + if xab[1].is_zero: + return True + elif xab[1].is_zero is None: + got_none = True + # take integration symbol out of free since it will be replaced + # with the free symbols in the limits + free.discard(xab[0]) + # add in the new symbols + for i in xab[1:]: + free.update(i.free_symbols) + if self.function.is_zero is False and got_none is False: + return False + + def transform(self, x, u): + r""" + Performs a change of variables from `x` to `u` using the relationship + given by `x` and `u` which will define the transformations `f` and `F` + (which are inverses of each other) as follows: + + 1) If `x` is a Symbol (which is a variable of integration) then `u` + will be interpreted as some function, f(u), with inverse F(u). + This, in effect, just makes the substitution of x with f(x). + + 2) If `u` is a Symbol then `x` will be interpreted as some function, + F(x), with inverse f(u). This is commonly referred to as + u-substitution. + + Once f and F have been identified, the transformation is made as + follows: + + .. math:: \int_a^b x \mathrm{d}x \rightarrow \int_{F(a)}^{F(b)} f(x) + \frac{\mathrm{d}}{\mathrm{d}x} + + where `F(x)` is the inverse of `f(x)` and the limits and integrand have + been corrected so as to retain the same value after integration. + + Notes + ===== + + The mappings, F(x) or f(u), must lead to a unique integral. Linear + or rational linear expression, ``2*x``, ``1/x`` and ``sqrt(x)``, will + always work; quadratic expressions like ``x**2 - 1`` are acceptable + as long as the resulting integrand does not depend on the sign of + the solutions (see examples). + + The integral will be returned unchanged if ``x`` is not a variable of + integration. + + ``x`` must be (or contain) only one of of the integration variables. If + ``u`` has more than one free symbol then it should be sent as a tuple + (``u``, ``uvar``) where ``uvar`` identifies which variable is replacing + the integration variable. + XXX can it contain another integration variable? + + Examples + ======== + + >>> from sympy.abc import a, x, u + >>> from sympy import Integral, cos, sqrt + + >>> i = Integral(x*cos(x**2 - 1), (x, 0, 1)) + + transform can change the variable of integration + + >>> i.transform(x, u) + Integral(u*cos(u**2 - 1), (u, 0, 1)) + + transform can perform u-substitution as long as a unique + integrand is obtained: + + >>> ui = i.transform(x**2 - 1, u) + >>> ui + Integral(cos(u)/2, (u, -1, 0)) + + This attempt fails because x = +/-sqrt(u + 1) and the + sign does not cancel out of the integrand: + + >>> Integral(cos(x**2 - 1), (x, 0, 1)).transform(x**2 - 1, u) + Traceback (most recent call last): + ... + ValueError: + The mapping between F(x) and f(u) did not give a unique integrand. + + transform can do a substitution. Here, the previous + result is transformed back into the original expression + using "u-substitution": + + >>> ui.transform(sqrt(u + 1), x) == i + True + + We can accomplish the same with a regular substitution: + + >>> ui.transform(u, x**2 - 1) == i + True + + If the `x` does not contain a symbol of integration then + the integral will be returned unchanged. Integral `i` does + not have an integration variable `a` so no change is made: + + >>> i.transform(a, x) == i + True + + When `u` has more than one free symbol the symbol that is + replacing `x` must be identified by passing `u` as a tuple: + + >>> Integral(x, (x, 0, 1)).transform(x, (u + a, u)) + Integral(a + u, (u, -a, 1 - a)) + >>> Integral(x, (x, 0, 1)).transform(x, (u + a, a)) + Integral(a + u, (a, -u, 1 - u)) + + See Also + ======== + + sympy.concrete.expr_with_limits.ExprWithLimits.variables : Lists the integration variables + as_dummy : Replace integration variables with dummy ones + """ + d = Dummy('d') + + xfree = x.free_symbols.intersection(self.variables) + if len(xfree) > 1: + raise ValueError( + 'F(x) can only contain one of: %s' % self.variables) + xvar = xfree.pop() if xfree else d + + if xvar not in self.variables: + return self + + u = sympify(u) + if isinstance(u, Expr): + ufree = u.free_symbols + if len(ufree) == 0: + raise ValueError(filldedent(''' + f(u) cannot be a constant''')) + if len(ufree) > 1: + raise ValueError(filldedent(''' + When f(u) has more than one free symbol, the one replacing x + must be identified: pass f(u) as (f(u), u)''')) + uvar = ufree.pop() + else: + u, uvar = u + if uvar not in u.free_symbols: + raise ValueError(filldedent(''' + Expecting a tuple (expr, symbol) where symbol identified + a free symbol in expr, but symbol is not in expr's free + symbols.''')) + if not isinstance(uvar, Symbol): + # This probably never evaluates to True + raise ValueError(filldedent(''' + Expecting a tuple (expr, symbol) but didn't get + a symbol; got %s''' % uvar)) + + if x.is_Symbol and u.is_Symbol: + return self.xreplace({x: u}) + + if not x.is_Symbol and not u.is_Symbol: + raise ValueError('either x or u must be a symbol') + + if uvar == xvar: + return self.transform(x, (u.subs(uvar, d), d)).xreplace({d: uvar}) + + if uvar in self.limits: + raise ValueError(filldedent(''' + u must contain the same variable as in x + or a variable that is not already an integration variable''')) + + from sympy.solvers.solvers import solve + if not x.is_Symbol: + F = [x.subs(xvar, d)] + soln = solve(u - x, xvar, check=False) + if not soln: + raise ValueError('no solution for solve(F(x) - f(u), x)') + f = [fi.subs(uvar, d) for fi in soln] + else: + f = [u.subs(uvar, d)] + from sympy.simplify.simplify import posify + pdiff, reps = posify(u - x) + puvar = uvar.subs([(v, k) for k, v in reps.items()]) + soln = [s.subs(reps) for s in solve(pdiff, puvar)] + if not soln: + raise ValueError('no solution for solve(F(x) - f(u), u)') + F = [fi.subs(xvar, d) for fi in soln] + + newfuncs = {(self.function.subs(xvar, fi)*fi.diff(d) + ).subs(d, uvar) for fi in f} + if len(newfuncs) > 1: + raise ValueError(filldedent(''' + The mapping between F(x) and f(u) did not give + a unique integrand.''')) + newfunc = newfuncs.pop() + + def _calc_limit_1(F, a, b): + """ + replace d with a, using subs if possible, otherwise limit + where sign of b is considered + """ + wok = F.subs(d, a) + if wok is S.NaN or wok.is_finite is False and a.is_finite: + return limit(sign(b)*F, d, a) + return wok + + def _calc_limit(a, b): + """ + replace d with a, using subs if possible, otherwise limit + where sign of b is considered + """ + avals = list({_calc_limit_1(Fi, a, b) for Fi in F}) + if len(avals) > 1: + raise ValueError(filldedent(''' + The mapping between F(x) and f(u) did not + give a unique limit.''')) + return avals[0] + + newlimits = [] + for xab in self.limits: + sym = xab[0] + if sym == xvar: + if len(xab) == 3: + a, b = xab[1:] + a, b = _calc_limit(a, b), _calc_limit(b, a) + if fuzzy_bool(a - b > 0): + a, b = b, a + newfunc = -newfunc + newlimits.append((uvar, a, b)) + elif len(xab) == 2: + a = _calc_limit(xab[1], 1) + newlimits.append((uvar, a)) + else: + newlimits.append(uvar) + else: + newlimits.append(xab) + + return self.func(newfunc, *newlimits) + + def doit(self, **hints): + """ + Perform the integration using any hints given. + + Examples + ======== + + >>> from sympy import Piecewise, S + >>> from sympy.abc import x, t + >>> p = x**2 + Piecewise((0, x/t < 0), (1, True)) + >>> p.integrate((t, S(4)/5, 1), (x, -1, 1)) + 1/3 + + See Also + ======== + + sympy.integrals.trigonometry.trigintegrate + sympy.integrals.heurisch.heurisch + sympy.integrals.rationaltools.ratint + as_sum : Approximate the integral using a sum + """ + if not hints.get('integrals', True): + return self + + deep = hints.get('deep', True) + meijerg = hints.get('meijerg', None) + conds = hints.get('conds', 'piecewise') + risch = hints.get('risch', None) + heurisch = hints.get('heurisch', None) + manual = hints.get('manual', None) + if len(list(filter(None, (manual, meijerg, risch, heurisch)))) > 1: + raise ValueError("At most one of manual, meijerg, risch, heurisch can be True") + elif manual: + meijerg = risch = heurisch = False + elif meijerg: + manual = risch = heurisch = False + elif risch: + manual = meijerg = heurisch = False + elif heurisch: + manual = meijerg = risch = False + eval_kwargs = {"meijerg": meijerg, "risch": risch, "manual": manual, "heurisch": heurisch, + "conds": conds} + + if conds not in ('separate', 'piecewise', 'none'): + raise ValueError('conds must be one of "separate", "piecewise", ' + '"none", got: %s' % conds) + + if risch and any(len(xab) > 1 for xab in self.limits): + raise ValueError('risch=True is only allowed for indefinite integrals.') + + # check for the trivial zero + if self.is_zero: + return S.Zero + + # hacks to handle integrals of + # nested summations + from sympy.concrete.summations import Sum + if isinstance(self.function, Sum): + if any(v in self.function.limits[0] for v in self.variables): + raise ValueError('Limit of the sum cannot be an integration variable.') + if any(l.is_infinite for l in self.function.limits[0][1:]): + return self + _i = self + _sum = self.function + return _sum.func(_i.func(_sum.function, *_i.limits).doit(), *_sum.limits).doit() + + # now compute and check the function + function = self.function + + # hack to use a consistent Heaviside(x, 1/2) + function = function.replace( + lambda x: isinstance(x, Heaviside) and x.args[1]*2 != 1, + lambda x: Heaviside(x.args[0])) + + if deep: + function = function.doit(**hints) + if function.is_zero: + return S.Zero + + # hacks to handle special cases + if isinstance(function, MatrixBase): + return function.applyfunc( + lambda f: self.func(f, *self.limits).doit(**hints)) + + if isinstance(function, FormalPowerSeries): + if len(self.limits) > 1: + raise NotImplementedError + xab = self.limits[0] + if len(xab) > 1: + return function.integrate(xab, **eval_kwargs) + else: + return function.integrate(xab[0], **eval_kwargs) + + # There is no trivial answer and special handling + # is done so continue + + # first make sure any definite limits have integration + # variables with matching assumptions + reps = {} + for xab in self.limits: + if len(xab) != 3: + # it makes sense to just make + # all x real but in practice with the + # current state of integration...this + # doesn't work out well + # x = xab[0] + # if x not in reps and not x.is_real: + # reps[x] = Dummy(real=True) + continue + x, a, b = xab + l = (a, b) + if all(i.is_nonnegative for i in l) and not x.is_nonnegative: + d = Dummy(positive=True) + elif all(i.is_nonpositive for i in l) and not x.is_nonpositive: + d = Dummy(negative=True) + elif all(i.is_real for i in l) and not x.is_real: + d = Dummy(real=True) + else: + d = None + if d: + reps[x] = d + if reps: + undo = {v: k for k, v in reps.items()} + did = self.xreplace(reps).doit(**hints) + if isinstance(did, tuple): # when separate=True + did = tuple([i.xreplace(undo) for i in did]) + else: + did = did.xreplace(undo) + return did + + # continue with existing assumptions + undone_limits = [] + # ulj = free symbols of any undone limits' upper and lower limits + ulj = set() + for xab in self.limits: + # compute uli, the free symbols in the + # Upper and Lower limits of limit I + if len(xab) == 1: + uli = set(xab[:1]) + elif len(xab) == 2: + uli = xab[1].free_symbols + elif len(xab) == 3: + uli = xab[1].free_symbols.union(xab[2].free_symbols) + # this integral can be done as long as there is no blocking + # limit that has been undone. An undone limit is blocking if + # it contains an integration variable that is in this limit's + # upper or lower free symbols or vice versa + if xab[0] in ulj or any(v[0] in uli for v in undone_limits): + undone_limits.append(xab) + ulj.update(uli) + function = self.func(*([function] + [xab])) + factored_function = function.factor() + if not isinstance(factored_function, Integral): + function = factored_function + continue + + if function.has(Abs, sign) and ( + (len(xab) < 3 and all(x.is_extended_real for x in xab)) or + (len(xab) == 3 and all(x.is_extended_real and not x.is_infinite for + x in xab[1:]))): + # some improper integrals are better off with Abs + xr = Dummy("xr", real=True) + function = (function.xreplace({xab[0]: xr}) + .rewrite(Piecewise).xreplace({xr: xab[0]})) + elif function.has(Min, Max): + function = function.rewrite(Piecewise) + if (function.has(Piecewise) and + not isinstance(function, Piecewise)): + function = piecewise_fold(function) + if isinstance(function, Piecewise): + if len(xab) == 1: + antideriv = function._eval_integral(xab[0], + **eval_kwargs) + else: + antideriv = self._eval_integral( + function, xab[0], **eval_kwargs) + else: + # There are a number of tradeoffs in using the + # Meijer G method. It can sometimes be a lot faster + # than other methods, and sometimes slower. And + # there are certain types of integrals for which it + # is more likely to work than others. These + # heuristics are incorporated in deciding what + # integration methods to try, in what order. See the + # integrate() docstring for details. + def try_meijerg(function, xab): + ret = None + if len(xab) == 3 and meijerg is not False: + x, a, b = xab + try: + res = meijerint_definite(function, x, a, b) + except NotImplementedError: + _debug('NotImplementedError ' + 'from meijerint_definite') + res = None + if res is not None: + f, cond = res + if conds == 'piecewise': + u = self.func(function, (x, a, b)) + # if Piecewise modifies cond too + # much it may not be recognized by + # _condsimp pattern matching so just + # turn off all evaluation + return Piecewise((f, cond), (u, True), + evaluate=False) + elif conds == 'separate': + if len(self.limits) != 1: + raise ValueError(filldedent(''' + conds=separate not supported in + multiple integrals''')) + ret = f, cond + else: + ret = f + return ret + + meijerg1 = meijerg + if (meijerg is not False and + len(xab) == 3 and xab[1].is_extended_real and xab[2].is_extended_real + and not function.is_Poly and + (xab[1].has(oo, -oo) or xab[2].has(oo, -oo))): + ret = try_meijerg(function, xab) + if ret is not None: + function = ret + continue + meijerg1 = False + # If the special meijerg code did not succeed in + # finding a definite integral, then the code using + # meijerint_indefinite will not either (it might + # find an antiderivative, but the answer is likely + # to be nonsensical). Thus if we are requested to + # only use Meijer G-function methods, we give up at + # this stage. Otherwise we just disable G-function + # methods. + if meijerg1 is False and meijerg is True: + antideriv = None + else: + antideriv = self._eval_integral( + function, xab[0], **eval_kwargs) + if antideriv is None and meijerg is True: + ret = try_meijerg(function, xab) + if ret is not None: + function = ret + continue + + final = hints.get('final', True) + # dotit may be iterated but floor terms making atan and acot + # continuous should only be added in the final round + if (final and not isinstance(antideriv, Integral) and + antideriv is not None): + for atan_term in antideriv.atoms(atan): + atan_arg = atan_term.args[0] + # Checking `atan_arg` to be linear combination of `tan` or `cot` + for tan_part in atan_arg.atoms(tan): + x1 = Dummy('x1') + tan_exp1 = atan_arg.subs(tan_part, x1) + # The coefficient of `tan` should be constant + coeff = tan_exp1.diff(x1) + if x1 not in coeff.free_symbols: + a = tan_part.args[0] + antideriv = antideriv.subs(atan_term, Add(atan_term, + sign(coeff)*pi*floor((a-pi/2)/pi))) + for cot_part in atan_arg.atoms(cot): + x1 = Dummy('x1') + cot_exp1 = atan_arg.subs(cot_part, x1) + # The coefficient of `cot` should be constant + coeff = cot_exp1.diff(x1) + if x1 not in coeff.free_symbols: + a = cot_part.args[0] + antideriv = antideriv.subs(atan_term, Add(atan_term, + sign(coeff)*pi*floor((a)/pi))) + + if antideriv is None: + undone_limits.append(xab) + function = self.func(*([function] + [xab])).factor() + factored_function = function.factor() + if not isinstance(factored_function, Integral): + function = factored_function + continue + else: + if len(xab) == 1: + function = antideriv + else: + if len(xab) == 3: + x, a, b = xab + elif len(xab) == 2: + x, b = xab + a = None + else: + raise NotImplementedError + + if deep: + if isinstance(a, Basic): + a = a.doit(**hints) + if isinstance(b, Basic): + b = b.doit(**hints) + + if antideriv.is_Poly: + gens = list(antideriv.gens) + gens.remove(x) + + antideriv = antideriv.as_expr() + + function = antideriv._eval_interval(x, a, b) + function = Poly(function, *gens) + else: + def is_indef_int(g, x): + return (isinstance(g, Integral) and + any(i == (x,) for i in g.limits)) + + def eval_factored(f, x, a, b): + # _eval_interval for integrals with + # (constant) factors + # a single indefinite integral is assumed + args = [] + for g in Mul.make_args(f): + if is_indef_int(g, x): + args.append(g._eval_interval(x, a, b)) + else: + args.append(g) + return Mul(*args) + + integrals, others, piecewises = [], [], [] + for f in Add.make_args(antideriv): + if any(is_indef_int(g, x) + for g in Mul.make_args(f)): + integrals.append(f) + elif any(isinstance(g, Piecewise) + for g in Mul.make_args(f)): + piecewises.append(piecewise_fold(f)) + else: + others.append(f) + uneval = Add(*[eval_factored(f, x, a, b) + for f in integrals]) + try: + evalued = Add(*others)._eval_interval(x, a, b) + evalued_pw = piecewise_fold(Add(*piecewises))._eval_interval(x, a, b) + function = uneval + evalued + evalued_pw + except NotImplementedError: + # This can happen if _eval_interval depends in a + # complicated way on limits that cannot be computed + undone_limits.append(xab) + function = self.func(*([function] + [xab])) + factored_function = function.factor() + if not isinstance(factored_function, Integral): + function = factored_function + return function + + def _eval_derivative(self, sym): + """Evaluate the derivative of the current Integral object by + differentiating under the integral sign [1], using the Fundamental + Theorem of Calculus [2] when possible. + + Explanation + =========== + + Whenever an Integral is encountered that is equivalent to zero or + has an integrand that is independent of the variable of integration + those integrals are performed. All others are returned as Integral + instances which can be resolved with doit() (provided they are integrable). + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Differentiation_under_the_integral_sign + .. [2] https://en.wikipedia.org/wiki/Fundamental_theorem_of_calculus + + Examples + ======== + + >>> from sympy import Integral + >>> from sympy.abc import x, y + >>> i = Integral(x + y, y, (y, 1, x)) + >>> i.diff(x) + Integral(x + y, (y, x)) + Integral(1, y, (y, 1, x)) + >>> i.doit().diff(x) == i.diff(x).doit() + True + >>> i.diff(y) + 0 + + The previous must be true since there is no y in the evaluated integral: + + >>> i.free_symbols + {x} + >>> i.doit() + 2*x**3/3 - x/2 - 1/6 + + """ + + # differentiate under the integral sign; we do not + # check for regularity conditions (TODO), see issue 4215 + + # get limits and the function + f, limits = self.function, list(self.limits) + + # the order matters if variables of integration appear in the limits + # so work our way in from the outside to the inside. + limit = limits.pop(-1) + if len(limit) == 3: + x, a, b = limit + elif len(limit) == 2: + x, b = limit + a = None + else: + a = b = None + x = limit[0] + + if limits: # f is the argument to an integral + f = self.func(f, *tuple(limits)) + + # assemble the pieces + def _do(f, ab): + dab_dsym = diff(ab, sym) + if not dab_dsym: + return S.Zero + if isinstance(f, Integral): + limits = [(x, x) if (len(l) == 1 and l[0] == x) else l + for l in f.limits] + f = self.func(f.function, *limits) + return f.subs(x, ab)*dab_dsym + + rv = S.Zero + if b is not None: + rv += _do(f, b) + if a is not None: + rv -= _do(f, a) + if len(limit) == 1 and sym == x: + # the dummy variable *is* also the real-world variable + arg = f + rv += arg + else: + # the dummy variable might match sym but it's + # only a dummy and the actual variable is determined + # by the limits, so mask off the variable of integration + # while differentiating + u = Dummy('u') + arg = f.subs(x, u).diff(sym).subs(u, x) + if arg: + rv += self.func(arg, (x, a, b)) + return rv + + def _eval_integral(self, f, x, meijerg=None, risch=None, manual=None, + heurisch=None, conds='piecewise',final=None): + """ + Calculate the anti-derivative to the function f(x). + + Explanation + =========== + + The following algorithms are applied (roughly in this order): + + 1. Simple heuristics (based on pattern matching and integral table): + + - most frequently used functions (e.g. polynomials, products of + trig functions) + + 2. Integration of rational functions: + + - A complete algorithm for integrating rational functions is + implemented (the Lazard-Rioboo-Trager algorithm). The algorithm + also uses the partial fraction decomposition algorithm + implemented in apart() as a preprocessor to make this process + faster. Note that the integral of a rational function is always + elementary, but in general, it may include a RootSum. + + 3. Full Risch algorithm: + + - The Risch algorithm is a complete decision + procedure for integrating elementary functions, which means that + given any elementary function, it will either compute an + elementary antiderivative, or else prove that none exists. + Currently, part of transcendental case is implemented, meaning + elementary integrals containing exponentials, logarithms, and + (soon!) trigonometric functions can be computed. The algebraic + case, e.g., functions containing roots, is much more difficult + and is not implemented yet. + + - If the routine fails (because the integrand is not elementary, or + because a case is not implemented yet), it continues on to the + next algorithms below. If the routine proves that the integrals + is nonelementary, it still moves on to the algorithms below, + because we might be able to find a closed-form solution in terms + of special functions. If risch=True, however, it will stop here. + + 4. The Meijer G-Function algorithm: + + - This algorithm works by first rewriting the integrand in terms of + very general Meijer G-Function (meijerg in SymPy), integrating + it, and then rewriting the result back, if possible. This + algorithm is particularly powerful for definite integrals (which + is actually part of a different method of Integral), since it can + compute closed-form solutions of definite integrals even when no + closed-form indefinite integral exists. But it also is capable + of computing many indefinite integrals as well. + + - Another advantage of this method is that it can use some results + about the Meijer G-Function to give a result in terms of a + Piecewise expression, which allows to express conditionally + convergent integrals. + + - Setting meijerg=True will cause integrate() to use only this + method. + + 5. The "manual integration" algorithm: + + - This algorithm tries to mimic how a person would find an + antiderivative by hand, for example by looking for a + substitution or applying integration by parts. This algorithm + does not handle as many integrands but can return results in a + more familiar form. + + - Sometimes this algorithm can evaluate parts of an integral; in + this case integrate() will try to evaluate the rest of the + integrand using the other methods here. + + - Setting manual=True will cause integrate() to use only this + method. + + 6. The Heuristic Risch algorithm: + + - This is a heuristic version of the Risch algorithm, meaning that + it is not deterministic. This is tried as a last resort because + it can be very slow. It is still used because not enough of the + full Risch algorithm is implemented, so that there are still some + integrals that can only be computed using this method. The goal + is to implement enough of the Risch and Meijer G-function methods + so that this can be deleted. + + Setting heurisch=True will cause integrate() to use only this + method. Set heurisch=False to not use it. + + """ + + from sympy.integrals.risch import risch_integrate, NonElementaryIntegral + from sympy.integrals.manualintegrate import manualintegrate + + if risch: + try: + return risch_integrate(f, x, conds=conds) + except NotImplementedError: + return None + + if manual: + try: + result = manualintegrate(f, x) + if result is not None and result.func != Integral: + return result + except (ValueError, PolynomialError): + pass + + eval_kwargs = {"meijerg": meijerg, "risch": risch, "manual": manual, + "heurisch": heurisch, "conds": conds} + + # if it is a poly(x) then let the polynomial integrate itself (fast) + # + # It is important to make this check first, otherwise the other code + # will return a SymPy expression instead of a Polynomial. + # + # see Polynomial for details. + if isinstance(f, Poly) and not (manual or meijerg or risch): + # Note: this is deprecated, but the deprecation warning is already + # issued in the Integral constructor. + return f.integrate(x) + + # Piecewise antiderivatives need to call special integrate. + if isinstance(f, Piecewise): + return f.piecewise_integrate(x, **eval_kwargs) + + # let's cut it short if `f` does not depend on `x`; if + # x is only a dummy, that will be handled below + if not f.has(x): + return f*x + + # try to convert to poly(x) and then integrate if successful (fast) + poly = f.as_poly(x) + if poly is not None and not (manual or meijerg or risch): + return poly.integrate().as_expr() + + if risch is not False: + try: + result, i = risch_integrate(f, x, separate_integral=True, + conds=conds) + except NotImplementedError: + pass + else: + if i: + # There was a nonelementary integral. Try integrating it. + + # if no part of the NonElementaryIntegral is integrated by + # the Risch algorithm, then use the original function to + # integrate, instead of re-written one + if result == 0: + return NonElementaryIntegral(f, x).doit(risch=False) + else: + return result + i.doit(risch=False) + else: + return result + + # since Integral(f=g1+g2+...) == Integral(g1) + Integral(g2) + ... + # we are going to handle Add terms separately, + # if `f` is not Add -- we only have one term + + # Note that in general, this is a bad idea, because Integral(g1) + + # Integral(g2) might not be computable, even if Integral(g1 + g2) is. + # For example, Integral(x**x + x**x*log(x)). But many heuristics only + # work term-wise. So we compute this step last, after trying + # risch_integrate. We also try risch_integrate again in this loop, + # because maybe the integral is a sum of an elementary part and a + # nonelementary part (like erf(x) + exp(x)). risch_integrate() is + # quite fast, so this is acceptable. + from sympy.simplify.fu import sincos_to_sum + parts = [] + args = Add.make_args(f) + for g in args: + coeff, g = g.as_independent(x) + + # g(x) = const + if g is S.One and not meijerg: + parts.append(coeff*x) + continue + + # g(x) = expr + O(x**n) + order_term = g.getO() + + if order_term is not None: + h = self._eval_integral(g.removeO(), x, **eval_kwargs) + + if h is not None: + h_order_expr = self._eval_integral(order_term.expr, x, **eval_kwargs) + + if h_order_expr is not None: + h_order_term = order_term.func( + h_order_expr, *order_term.variables) + parts.append(coeff*(h + h_order_term)) + continue + + # NOTE: if there is O(x**n) and we fail to integrate then + # there is no point in trying other methods because they + # will fail, too. + return None + + # c + # g(x) = (a*x+b) + if g.is_Pow and not g.exp.has(x) and not meijerg: + a = Wild('a', exclude=[x]) + b = Wild('b', exclude=[x]) + + M = g.base.match(a*x + b) + + if M is not None: + if g.exp == -1: + h = log(g.base) + elif conds != 'piecewise': + h = g.base**(g.exp + 1) / (g.exp + 1) + else: + h1 = log(g.base) + h2 = g.base**(g.exp + 1) / (g.exp + 1) + h = Piecewise((h2, Ne(g.exp, -1)), (h1, True)) + + parts.append(coeff * h / M[a]) + continue + + # poly(x) + # g(x) = ------- + # poly(x) + if g.is_rational_function(x) and not (manual or meijerg or risch): + parts.append(coeff * ratint(g, x)) + continue + + if not (manual or meijerg or risch): + # g(x) = Mul(trig) + h = trigintegrate(g, x, conds=conds) + if h is not None: + parts.append(coeff * h) + continue + + # g(x) has at least a DiracDelta term + h = deltaintegrate(g, x) + if h is not None: + parts.append(coeff * h) + continue + + from .singularityfunctions import singularityintegrate + # g(x) has at least a Singularity Function term + h = singularityintegrate(g, x) + if h is not None: + parts.append(coeff * h) + continue + + # Try risch again. + if risch is not False: + try: + h, i = risch_integrate(g, x, + separate_integral=True, conds=conds) + except NotImplementedError: + h = None + else: + if i: + h = h + i.doit(risch=False) + + parts.append(coeff*h) + continue + + # fall back to heurisch + if heurisch is not False: + from sympy.integrals.heurisch import (heurisch as heurisch_, + heurisch_wrapper) + try: + if conds == 'piecewise': + h = heurisch_wrapper(g, x, hints=[]) + else: + h = heurisch_(g, x, hints=[]) + except PolynomialError: + # XXX: this exception means there is a bug in the + # implementation of heuristic Risch integration + # algorithm. + h = None + else: + h = None + + if meijerg is not False and h is None: + # rewrite using G functions + try: + h = meijerint_indefinite(g, x) + except NotImplementedError: + _debug('NotImplementedError from meijerint_definite') + if h is not None: + parts.append(coeff * h) + continue + + if h is None and manual is not False: + try: + result = manualintegrate(g, x) + if result is not None and not isinstance(result, Integral): + if result.has(Integral) and not manual: + # Try to have other algorithms do the integrals + # manualintegrate can't handle, + # unless we were asked to use manual only. + # Keep the rest of eval_kwargs in case another + # method was set to False already + new_eval_kwargs = eval_kwargs + new_eval_kwargs["manual"] = False + new_eval_kwargs["final"] = False + result = result.func(*[ + arg.doit(**new_eval_kwargs) if + arg.has(Integral) else arg + for arg in result.args + ]).expand(multinomial=False, + log=False, + power_exp=False, + power_base=False) + if not result.has(Integral): + parts.append(coeff * result) + continue + except (ValueError, PolynomialError): + # can't handle some SymPy expressions + pass + + # if we failed maybe it was because we had + # a product that could have been expanded, + # so let's try an expansion of the whole + # thing before giving up; we don't try this + # at the outset because there are things + # that cannot be solved unless they are + # NOT expanded e.g., x**x*(1+log(x)). There + # should probably be a checker somewhere in this + # routine to look for such cases and try to do + # collection on the expressions if they are already + # in an expanded form + if not h and len(args) == 1: + f = sincos_to_sum(f).expand(mul=True, deep=False) + if f.is_Add: + # Note: risch will be identical on the expanded + # expression, but maybe it will be able to pick out parts, + # like x*(exp(x) + erf(x)). + return self._eval_integral(f, x, **eval_kwargs) + + if h is not None: + parts.append(coeff * h) + else: + return None + + return Add(*parts) + + def _eval_lseries(self, x, logx=None, cdir=0): + expr = self.as_dummy() + symb = x + for l in expr.limits: + if x in l[1:]: + symb = l[0] + break + for term in expr.function.lseries(symb, logx): + yield integrate(term, *expr.limits) + + def _eval_nseries(self, x, n, logx=None, cdir=0): + symb = x + for l in self.limits: + if x in l[1:]: + symb = l[0] + break + terms, order = self.function.nseries( + x=symb, n=n, logx=logx).as_coeff_add(Order) + order = [o.subs(symb, x) for o in order] + return integrate(terms, *self.limits) + Add(*order)*x + + def _eval_as_leading_term(self, x, logx, cdir): + series_gen = self.args[0].lseries(x) + for leading_term in series_gen: + if leading_term != 0: + break + return integrate(leading_term, *self.args[1:]) + + def _eval_simplify(self, **kwargs): + expr = factor_terms(self) + if isinstance(expr, Integral): + from sympy.simplify.simplify import simplify + return expr.func(*[simplify(i, **kwargs) for i in expr.args]) + return expr.simplify(**kwargs) + + def as_sum(self, n=None, method="midpoint", evaluate=True): + """ + Approximates a definite integral by a sum. + + Parameters + ========== + + n : + The number of subintervals to use, optional. + method : + One of: 'left', 'right', 'midpoint', 'trapezoid'. + evaluate : bool + If False, returns an unevaluated Sum expression. The default + is True, evaluate the sum. + + Notes + ===== + + These methods of approximate integration are described in [1]. + + Examples + ======== + + >>> from sympy import Integral, sin, sqrt + >>> from sympy.abc import x, n + >>> e = Integral(sin(x), (x, 3, 7)) + >>> e + Integral(sin(x), (x, 3, 7)) + + For demonstration purposes, this interval will only be split into 2 + regions, bounded by [3, 5] and [5, 7]. + + The left-hand rule uses function evaluations at the left of each + interval: + + >>> e.as_sum(2, 'left') + 2*sin(5) + 2*sin(3) + + The midpoint rule uses evaluations at the center of each interval: + + >>> e.as_sum(2, 'midpoint') + 2*sin(4) + 2*sin(6) + + The right-hand rule uses function evaluations at the right of each + interval: + + >>> e.as_sum(2, 'right') + 2*sin(5) + 2*sin(7) + + The trapezoid rule uses function evaluations on both sides of the + intervals. This is equivalent to taking the average of the left and + right hand rule results: + + >>> s = e.as_sum(2, 'trapezoid') + >>> s + 2*sin(5) + sin(3) + sin(7) + >>> (e.as_sum(2, 'left') + e.as_sum(2, 'right'))/2 == s + True + + Here, the discontinuity at x = 0 can be avoided by using the + midpoint or right-hand method: + + >>> e = Integral(1/sqrt(x), (x, 0, 1)) + >>> e.as_sum(5).n(4) + 1.730 + >>> e.as_sum(10).n(4) + 1.809 + >>> e.doit().n(4) # the actual value is 2 + 2.000 + + The left- or trapezoid method will encounter the discontinuity and + return infinity: + + >>> e.as_sum(5, 'left') + zoo + + The number of intervals can be symbolic. If omitted, a dummy symbol + will be used for it. + + >>> e = Integral(x**2, (x, 0, 2)) + >>> e.as_sum(n, 'right').expand() + 8/3 + 4/n + 4/(3*n**2) + + This shows that the midpoint rule is more accurate, as its error + term decays as the square of n: + + >>> e.as_sum(method='midpoint').expand() + 8/3 - 2/(3*_n**2) + + A symbolic sum is returned with evaluate=False: + + >>> e.as_sum(n, 'midpoint', evaluate=False) + 2*Sum((2*_k/n - 1/n)**2, (_k, 1, n))/n + + See Also + ======== + + Integral.doit : Perform the integration using any hints + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Riemann_sum#Riemann_summation_methods + """ + + from sympy.concrete.summations import Sum + limits = self.limits + if len(limits) > 1: + raise NotImplementedError( + "Multidimensional midpoint rule not implemented yet") + else: + limit = limits[0] + if (len(limit) != 3 or limit[1].is_finite is False or + limit[2].is_finite is False): + raise ValueError("Expecting a definite integral over " + "a finite interval.") + if n is None: + n = Dummy('n', integer=True, positive=True) + else: + n = sympify(n) + if (n.is_positive is False or n.is_integer is False or + n.is_finite is False): + raise ValueError("n must be a positive integer, got %s" % n) + x, a, b = limit + dx = (b - a)/n + k = Dummy('k', integer=True, positive=True) + f = self.function + + if method == "left": + result = dx*Sum(f.subs(x, a + (k-1)*dx), (k, 1, n)) + elif method == "right": + result = dx*Sum(f.subs(x, a + k*dx), (k, 1, n)) + elif method == "midpoint": + result = dx*Sum(f.subs(x, a + k*dx - dx/2), (k, 1, n)) + elif method == "trapezoid": + result = dx*((f.subs(x, a) + f.subs(x, b))/2 + + Sum(f.subs(x, a + k*dx), (k, 1, n - 1))) + else: + raise ValueError("Unknown method %s" % method) + return result.doit() if evaluate else result + + def principal_value(self, **kwargs): + """ + Compute the Cauchy Principal Value of the definite integral of a real function in the given interval + on the real axis. + + Explanation + =========== + + In mathematics, the Cauchy principal value, is a method for assigning values to certain improper + integrals which would otherwise be undefined. + + Examples + ======== + + >>> from sympy import Integral, oo + >>> from sympy.abc import x + >>> Integral(x+1, (x, -oo, oo)).principal_value() + oo + >>> f = 1 / (x**3) + >>> Integral(f, (x, -oo, oo)).principal_value() + 0 + >>> Integral(f, (x, -10, 10)).principal_value() + 0 + >>> Integral(f, (x, -10, oo)).principal_value() + Integral(f, (x, -oo, 10)).principal_value() + 0 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Cauchy_principal_value + .. [2] https://mathworld.wolfram.com/CauchyPrincipalValue.html + """ + if len(self.limits) != 1 or len(list(self.limits[0])) != 3: + raise ValueError("You need to insert a variable, lower_limit, and upper_limit correctly to calculate " + "cauchy's principal value") + x, a, b = self.limits[0] + if not (a.is_comparable and b.is_comparable and a <= b): + raise ValueError("The lower_limit must be smaller than or equal to the upper_limit to calculate " + "cauchy's principal value. Also, a and b need to be comparable.") + if a == b: + return S.Zero + + from sympy.calculus.singularities import singularities + + r = Dummy('r') + f = self.function + singularities_list = [s for s in singularities(f, x) if s.is_comparable and a <= s <= b] + for i in singularities_list: + if i in (a, b): + raise ValueError( + 'The principal value is not defined in the given interval due to singularity at %d.' % (i)) + F = integrate(f, x, **kwargs) + if F.has(Integral): + return self + if a is -oo and b is oo: + I = limit(F - F.subs(x, -x), x, oo) + else: + I = limit(F, x, b, '-') - limit(F, x, a, '+') + for s in singularities_list: + I += limit(((F.subs(x, s - r)) - F.subs(x, s + r)), r, 0, '+') + return I + + + +def integrate(*args, meijerg=None, conds='piecewise', risch=None, heurisch=None, manual=None, **kwargs): + """integrate(f, var, ...) + + .. deprecated:: 1.6 + + Using ``integrate()`` with :class:`~.Poly` is deprecated. Use + :meth:`.Poly.integrate` instead. See :ref:`deprecated-integrate-poly`. + + Explanation + =========== + + Compute definite or indefinite integral of one or more variables + using Risch-Norman algorithm and table lookup. This procedure is + able to handle elementary algebraic and transcendental functions + and also a huge class of special functions, including Airy, + Bessel, Whittaker and Lambert. + + var can be: + + - a symbol -- indefinite integration + - a tuple (symbol, a) -- indefinite integration with result + given with ``a`` replacing ``symbol`` + - a tuple (symbol, a, b) -- definite integration + + Several variables can be specified, in which case the result is + multiple integration. (If var is omitted and the integrand is + univariate, the indefinite integral in that variable will be performed.) + + Indefinite integrals are returned without terms that are independent + of the integration variables. (see examples) + + Definite improper integrals often entail delicate convergence + conditions. Pass conds='piecewise', 'separate' or 'none' to have + these returned, respectively, as a Piecewise function, as a separate + result (i.e. result will be a tuple), or not at all (default is + 'piecewise'). + + **Strategy** + + SymPy uses various approaches to definite integration. One method is to + find an antiderivative for the integrand, and then use the fundamental + theorem of calculus. Various functions are implemented to integrate + polynomial, rational and trigonometric functions, and integrands + containing DiracDelta terms. + + SymPy also implements the part of the Risch algorithm, which is a decision + procedure for integrating elementary functions, i.e., the algorithm can + either find an elementary antiderivative, or prove that one does not + exist. There is also a (very successful, albeit somewhat slow) general + implementation of the heuristic Risch algorithm. This algorithm will + eventually be phased out as more of the full Risch algorithm is + implemented. See the docstring of Integral._eval_integral() for more + details on computing the antiderivative using algebraic methods. + + The option risch=True can be used to use only the (full) Risch algorithm. + This is useful if you want to know if an elementary function has an + elementary antiderivative. If the indefinite Integral returned by this + function is an instance of NonElementaryIntegral, that means that the + Risch algorithm has proven that integral to be non-elementary. Note that + by default, additional methods (such as the Meijer G method outlined + below) are tried on these integrals, as they may be expressible in terms + of special functions, so if you only care about elementary answers, use + risch=True. Also note that an unevaluated Integral returned by this + function is not necessarily a NonElementaryIntegral, even with risch=True, + as it may just be an indication that the particular part of the Risch + algorithm needed to integrate that function is not yet implemented. + + Another family of strategies comes from re-writing the integrand in + terms of so-called Meijer G-functions. Indefinite integrals of a + single G-function can always be computed, and the definite integral + of a product of two G-functions can be computed from zero to + infinity. Various strategies are implemented to rewrite integrands + as G-functions, and use this information to compute integrals (see + the ``meijerint`` module). + + The option manual=True can be used to use only an algorithm that tries + to mimic integration by hand. This algorithm does not handle as many + integrands as the other algorithms implemented but may return results in + a more familiar form. The ``manualintegrate`` module has functions that + return the steps used (see the module docstring for more information). + + In general, the algebraic methods work best for computing + antiderivatives of (possibly complicated) combinations of elementary + functions. The G-function methods work best for computing definite + integrals from zero to infinity of moderately complicated + combinations of special functions, or indefinite integrals of very + simple combinations of special functions. + + The strategy employed by the integration code is as follows: + + - If computing a definite integral, and both limits are real, + and at least one limit is +- oo, try the G-function method of + definite integration first. + + - Try to find an antiderivative, using all available methods, ordered + by performance (that is try fastest method first, slowest last; in + particular polynomial integration is tried first, Meijer + G-functions second to last, and heuristic Risch last). + + - If still not successful, try G-functions irrespective of the + limits. + + The option meijerg=True, False, None can be used to, respectively: + always use G-function methods and no others, never use G-function + methods, or use all available methods (in order as described above). + It defaults to None. + + Examples + ======== + + >>> from sympy import integrate, log, exp, oo + >>> from sympy.abc import a, x, y + + >>> integrate(x*y, x) + x**2*y/2 + + >>> integrate(log(x), x) + x*log(x) - x + + >>> integrate(log(x), (x, 1, a)) + a*log(a) - a + 1 + + >>> integrate(x) + x**2/2 + + Terms that are independent of x are dropped by indefinite integration: + + >>> from sympy import sqrt + >>> integrate(sqrt(1 + x), (x, 0, x)) + 2*(x + 1)**(3/2)/3 - 2/3 + >>> integrate(sqrt(1 + x), x) + 2*(x + 1)**(3/2)/3 + + >>> integrate(x*y) + Traceback (most recent call last): + ... + ValueError: specify integration variables to integrate x*y + + Note that ``integrate(x)`` syntax is meant only for convenience + in interactive sessions and should be avoided in library code. + + >>> integrate(x**a*exp(-x), (x, 0, oo)) # same as conds='piecewise' + Piecewise((gamma(a + 1), re(a) > -1), + (Integral(x**a*exp(-x), (x, 0, oo)), True)) + + >>> integrate(x**a*exp(-x), (x, 0, oo), conds='none') + gamma(a + 1) + + >>> integrate(x**a*exp(-x), (x, 0, oo), conds='separate') + (gamma(a + 1), re(a) > -1) + + See Also + ======== + + Integral, Integral.doit + + """ + doit_flags = { + 'deep': False, + 'meijerg': meijerg, + 'conds': conds, + 'risch': risch, + 'heurisch': heurisch, + 'manual': manual + } + + integral = Integral(*args, **kwargs) + + if isinstance(integral, Integral): + return integral.doit(**doit_flags) + else: + new_args = [a.doit(**doit_flags) if isinstance(a, Integral) else a + for a in integral.args] + return integral.func(*new_args) + +def line_integrate(field, curve, vars): + """line_integrate(field, Curve, variables) + + Compute the line integral. + + Examples + ======== + + >>> from sympy import Curve, line_integrate, E, ln + >>> from sympy.abc import x, y, t + >>> C = Curve([E**t + 1, E**t - 1], (t, 0, ln(2))) + >>> line_integrate(x + y, C, [x, y]) + 3*sqrt(2) + + See Also + ======== + + sympy.integrals.integrals.integrate, Integral + """ + from sympy.geometry import Curve + F = sympify(field) + if not F: + raise ValueError( + "Expecting function specifying field as first argument.") + if not isinstance(curve, Curve): + raise ValueError("Expecting Curve entity as second argument.") + if not is_sequence(vars): + raise ValueError("Expecting ordered iterable for variables.") + if len(curve.functions) != len(vars): + raise ValueError("Field variable size does not match curve dimension.") + + if curve.parameter in vars: + raise ValueError("Curve parameter clashes with field parameters.") + + # Calculate derivatives for line parameter functions + # F(r) -> F(r(t)) and finally F(r(t)*r'(t)) + Ft = F + dldt = 0 + for i, var in enumerate(vars): + _f = curve.functions[i] + _dn = diff(_f, curve.parameter) + # ...arc length + dldt = dldt + (_dn * _dn) + Ft = Ft.subs(var, _f) + Ft = Ft * sqrt(dldt) + + integral = Integral(Ft, curve.limits).doit(deep=False) + return integral + + +### Property function dispatching ### + +@shape.register(Integral) +def _(expr): + return shape(expr.function) + +# Delayed imports +from .deltafunctions import deltaintegrate +from .meijerint import meijerint_definite, meijerint_indefinite, _debug +from .trigonometry import trigintegrate diff --git a/.venv/lib/python3.13/site-packages/sympy/integrals/intpoly.py b/.venv/lib/python3.13/site-packages/sympy/integrals/intpoly.py new file mode 100644 index 0000000000000000000000000000000000000000..38fd071183fb2192f4c1443d04c8f0ecfb6cc4ea --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/integrals/intpoly.py @@ -0,0 +1,1302 @@ +""" +Module to implement integration of uni/bivariate polynomials over +2D Polytopes and uni/bi/trivariate polynomials over 3D Polytopes. + +Uses evaluation techniques as described in Chin et al. (2015) [1]. + + +References +=========== + +.. [1] Chin, Eric B., Jean B. Lasserre, and N. Sukumar. "Numerical integration +of homogeneous functions on convex and nonconvex polygons and polyhedra." +Computational Mechanics 56.6 (2015): 967-981 + +PDF link : http://dilbert.engr.ucdavis.edu/~suku/quadrature/cls-integration.pdf +""" + +from functools import cmp_to_key + +from sympy.abc import x, y, z +from sympy.core import S, diff, Expr, Symbol +from sympy.core.sympify import _sympify +from sympy.geometry import Segment2D, Polygon, Point, Point2D +from sympy.polys.polytools import LC, gcd_list, degree_list, Poly +from sympy.simplify.simplify import nsimplify + + +def polytope_integrate(poly, expr=None, *, clockwise=False, max_degree=None): + """Integrates polynomials over 2/3-Polytopes. + + Explanation + =========== + + This function accepts the polytope in ``poly`` and the function in ``expr`` + (uni/bi/trivariate polynomials are implemented) and returns + the exact integral of ``expr`` over ``poly``. + + Parameters + ========== + + poly : The input Polygon. + + expr : The input polynomial. + + clockwise : Binary value to sort input points of 2-Polytope clockwise.(Optional) + + max_degree : The maximum degree of any monomial of the input polynomial.(Optional) + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import Point, Polygon + >>> from sympy.integrals.intpoly import polytope_integrate + >>> polygon = Polygon(Point(0, 0), Point(0, 1), Point(1, 1), Point(1, 0)) + >>> polys = [1, x, y, x*y, x**2*y, x*y**2] + >>> expr = x*y + >>> polytope_integrate(polygon, expr) + 1/4 + >>> polytope_integrate(polygon, polys, max_degree=3) + {1: 1, x: 1/2, y: 1/2, x*y: 1/4, x*y**2: 1/6, x**2*y: 1/6} + """ + if clockwise: + if isinstance(poly, Polygon): + poly = Polygon(*point_sort(poly.vertices), evaluate=False) + else: + raise TypeError("clockwise=True works for only 2-Polytope" + "V-representation input") + + if isinstance(poly, Polygon): + # For Vertex Representation(2D case) + hp_params = hyperplane_parameters(poly) + facets = poly.sides + elif len(poly[0]) == 2: + # For Hyperplane Representation(2D case) + plen = len(poly) + if len(poly[0][0]) == 2: + intersections = [intersection(poly[(i - 1) % plen], poly[i], + "plane2D") + for i in range(0, plen)] + hp_params = poly + lints = len(intersections) + facets = [Segment2D(intersections[i], + intersections[(i + 1) % lints]) + for i in range(lints)] + else: + raise NotImplementedError("Integration for H-representation 3D" + "case not implemented yet.") + else: + # For Vertex Representation(3D case) + vertices = poly[0] + facets = poly[1:] + hp_params = hyperplane_parameters(facets, vertices) + + if max_degree is None: + if expr is None: + raise TypeError('Input expression must be a valid SymPy expression') + return main_integrate3d(expr, facets, vertices, hp_params) + + if max_degree is not None: + result = {} + if expr is not None: + f_expr = [] + for e in expr: + _ = decompose(e) + if len(_) == 1 and not _.popitem()[0]: + f_expr.append(e) + elif Poly(e).total_degree() <= max_degree: + f_expr.append(e) + expr = f_expr + + if not isinstance(expr, list) and expr is not None: + raise TypeError('Input polynomials must be list of expressions') + + if len(hp_params[0][0]) == 3: + result_dict = main_integrate3d(0, facets, vertices, hp_params, + max_degree) + else: + result_dict = main_integrate(0, facets, hp_params, max_degree) + + if expr is None: + return result_dict + + for poly in expr: + poly = _sympify(poly) + if poly not in result: + if poly.is_zero: + result[S.Zero] = S.Zero + continue + integral_value = S.Zero + monoms = decompose(poly, separate=True) + for monom in monoms: + monom = nsimplify(monom) + coeff, m = strip(monom) + integral_value += result_dict[m] * coeff + result[poly] = integral_value + return result + + if expr is None: + raise TypeError('Input expression must be a valid SymPy expression') + + return main_integrate(expr, facets, hp_params) + + +def strip(monom): + if monom.is_zero: + return S.Zero, S.Zero + elif monom.is_number: + return monom, S.One + else: + coeff = LC(monom) + return coeff, monom / coeff + +def _polynomial_integrate(polynomials, facets, hp_params): + dims = (x, y) + dim_length = len(dims) + integral_value = S.Zero + for deg in polynomials: + poly_contribute = S.Zero + facet_count = 0 + for hp in hp_params: + value_over_boundary = integration_reduction(facets, + facet_count, + hp[0], hp[1], + polynomials[deg], + dims, deg) + poly_contribute += value_over_boundary * (hp[1] / norm(hp[0])) + facet_count += 1 + poly_contribute /= (dim_length + deg) + integral_value += poly_contribute + + return integral_value + + +def main_integrate3d(expr, facets, vertices, hp_params, max_degree=None): + """Function to translate the problem of integrating uni/bi/tri-variate + polynomials over a 3-Polytope to integrating over its faces. + This is done using Generalized Stokes' Theorem and Euler's Theorem. + + Parameters + ========== + + expr : + The input polynomial. + facets : + Faces of the 3-Polytope(expressed as indices of `vertices`). + vertices : + Vertices that constitute the Polytope. + hp_params : + Hyperplane Parameters of the facets. + max_degree : optional + Max degree of constituent monomial in given list of polynomial. + + Examples + ======== + + >>> from sympy.integrals.intpoly import main_integrate3d, \ + hyperplane_parameters + >>> cube = [[(0, 0, 0), (0, 0, 5), (0, 5, 0), (0, 5, 5), (5, 0, 0),\ + (5, 0, 5), (5, 5, 0), (5, 5, 5)],\ + [2, 6, 7, 3], [3, 7, 5, 1], [7, 6, 4, 5], [1, 5, 4, 0],\ + [3, 1, 0, 2], [0, 4, 6, 2]] + >>> vertices = cube[0] + >>> faces = cube[1:] + >>> hp_params = hyperplane_parameters(faces, vertices) + >>> main_integrate3d(1, faces, vertices, hp_params) + -125 + """ + result = {} + dims = (x, y, z) + dim_length = len(dims) + if max_degree: + grad_terms = gradient_terms(max_degree, 3) + flat_list = [term for z_terms in grad_terms + for x_term in z_terms + for term in x_term] + + for term in flat_list: + result[term[0]] = 0 + + for facet_count, hp in enumerate(hp_params): + a, b = hp[0], hp[1] + x0 = vertices[facets[facet_count][0]] + + for i, monom in enumerate(flat_list): + # Every monomial is a tuple : + # (term, x_degree, y_degree, z_degree, value over boundary) + expr, x_d, y_d, z_d, z_index, y_index, x_index, _ = monom + degree = x_d + y_d + z_d + if b.is_zero: + value_over_face = S.Zero + else: + value_over_face = \ + integration_reduction_dynamic(facets, facet_count, a, + b, expr, degree, dims, + x_index, y_index, + z_index, x0, grad_terms, + i, vertices, hp) + monom[7] = value_over_face + result[expr] += value_over_face * \ + (b / norm(a)) / (dim_length + x_d + y_d + z_d) + return result + else: + integral_value = S.Zero + polynomials = decompose(expr) + for deg in polynomials: + poly_contribute = S.Zero + facet_count = 0 + for i, facet in enumerate(facets): + hp = hp_params[i] + if hp[1].is_zero: + continue + pi = polygon_integrate(facet, hp, i, facets, vertices, expr, deg) + poly_contribute += pi *\ + (hp[1] / norm(tuple(hp[0]))) + facet_count += 1 + poly_contribute /= (dim_length + deg) + integral_value += poly_contribute + return integral_value + + +def main_integrate(expr, facets, hp_params, max_degree=None): + """Function to translate the problem of integrating univariate/bivariate + polynomials over a 2-Polytope to integrating over its boundary facets. + This is done using Generalized Stokes's Theorem and Euler's Theorem. + + Parameters + ========== + + expr : + The input polynomial. + facets : + Facets(Line Segments) of the 2-Polytope. + hp_params : + Hyperplane Parameters of the facets. + max_degree : optional + The maximum degree of any monomial of the input polynomial. + + >>> from sympy.abc import x, y + >>> from sympy.integrals.intpoly import main_integrate,\ + hyperplane_parameters + >>> from sympy import Point, Polygon + >>> triangle = Polygon(Point(0, 3), Point(5, 3), Point(1, 1)) + >>> facets = triangle.sides + >>> hp_params = hyperplane_parameters(triangle) + >>> main_integrate(x**2 + y**2, facets, hp_params) + 325/6 + """ + dims = (x, y) + dim_length = len(dims) + result = {} + + if max_degree: + grad_terms = [[0, 0, 0, 0]] + gradient_terms(max_degree) + + for facet_count, hp in enumerate(hp_params): + a, b = hp[0], hp[1] + x0 = facets[facet_count].points[0] + + for i, monom in enumerate(grad_terms): + # Every monomial is a tuple : + # (term, x_degree, y_degree, value over boundary) + m, x_d, y_d, _ = monom + value = result.get(m, None) + degree = S.Zero + if b.is_zero: + value_over_boundary = S.Zero + else: + degree = x_d + y_d + value_over_boundary = \ + integration_reduction_dynamic(facets, facet_count, a, + b, m, degree, dims, x_d, + y_d, max_degree, x0, + grad_terms, i) + monom[3] = value_over_boundary + if value is not None: + result[m] += value_over_boundary * \ + (b / norm(a)) / (dim_length + degree) + else: + result[m] = value_over_boundary * \ + (b / norm(a)) / (dim_length + degree) + return result + else: + if not isinstance(expr, list): + polynomials = decompose(expr) + return _polynomial_integrate(polynomials, facets, hp_params) + else: + return {e: _polynomial_integrate(decompose(e), facets, hp_params) for e in expr} + + +def polygon_integrate(facet, hp_param, index, facets, vertices, expr, degree): + """Helper function to integrate the input uni/bi/trivariate polynomial + over a certain face of the 3-Polytope. + + Parameters + ========== + + facet : + Particular face of the 3-Polytope over which ``expr`` is integrated. + index : + The index of ``facet`` in ``facets``. + facets : + Faces of the 3-Polytope(expressed as indices of `vertices`). + vertices : + Vertices that constitute the facet. + expr : + The input polynomial. + degree : + Degree of ``expr``. + + Examples + ======== + + >>> from sympy.integrals.intpoly import polygon_integrate + >>> cube = [[(0, 0, 0), (0, 0, 5), (0, 5, 0), (0, 5, 5), (5, 0, 0),\ + (5, 0, 5), (5, 5, 0), (5, 5, 5)],\ + [2, 6, 7, 3], [3, 7, 5, 1], [7, 6, 4, 5], [1, 5, 4, 0],\ + [3, 1, 0, 2], [0, 4, 6, 2]] + >>> facet = cube[1] + >>> facets = cube[1:] + >>> vertices = cube[0] + >>> polygon_integrate(facet, [(0, 1, 0), 5], 0, facets, vertices, 1, 0) + -25 + """ + expr = S(expr) + if expr.is_zero: + return S.Zero + result = S.Zero + x0 = vertices[facet[0]] + facet_len = len(facet) + for i, fac in enumerate(facet): + side = (vertices[fac], vertices[facet[(i + 1) % facet_len]]) + result += distance_to_side(x0, side, hp_param[0]) *\ + lineseg_integrate(facet, i, side, expr, degree) + if not expr.is_number: + expr = diff(expr, x) * x0[0] + diff(expr, y) * x0[1] +\ + diff(expr, z) * x0[2] + result += polygon_integrate(facet, hp_param, index, facets, vertices, + expr, degree - 1) + result /= (degree + 2) + return result + + +def distance_to_side(point, line_seg, A): + """Helper function to compute the signed distance between given 3D point + and a line segment. + + Parameters + ========== + + point : 3D Point + line_seg : Line Segment + + Examples + ======== + + >>> from sympy.integrals.intpoly import distance_to_side + >>> point = (0, 0, 0) + >>> distance_to_side(point, [(0, 0, 1), (0, 1, 0)], (1, 0, 0)) + -sqrt(2)/2 + """ + x1, x2 = line_seg + rev_normal = [-1 * S(i)/norm(A) for i in A] + vector = [x2[i] - x1[i] for i in range(0, 3)] + vector = [vector[i]/norm(vector) for i in range(0, 3)] + + n_side = cross_product((0, 0, 0), rev_normal, vector) + vectorx0 = [line_seg[0][i] - point[i] for i in range(0, 3)] + dot_product = sum(vectorx0[i] * n_side[i] for i in range(0, 3)) + + return dot_product + + +def lineseg_integrate(polygon, index, line_seg, expr, degree): + """Helper function to compute the line integral of ``expr`` over ``line_seg``. + + Parameters + =========== + + polygon : + Face of a 3-Polytope. + index : + Index of line_seg in polygon. + line_seg : + Line Segment. + + Examples + ======== + + >>> from sympy.integrals.intpoly import lineseg_integrate + >>> polygon = [(0, 5, 0), (5, 5, 0), (5, 5, 5), (0, 5, 5)] + >>> line_seg = [(0, 5, 0), (5, 5, 0)] + >>> lineseg_integrate(polygon, 0, line_seg, 1, 0) + 5 + """ + expr = _sympify(expr) + if expr.is_zero: + return S.Zero + result = S.Zero + x0 = line_seg[0] + distance = norm(tuple([line_seg[1][i] - line_seg[0][i] for i in + range(3)])) + if isinstance(expr, Expr): + expr_dict = {x: line_seg[1][0], + y: line_seg[1][1], + z: line_seg[1][2]} + result += distance * expr.subs(expr_dict) + else: + result += distance * expr + + expr = diff(expr, x) * x0[0] + diff(expr, y) * x0[1] +\ + diff(expr, z) * x0[2] + + result += lineseg_integrate(polygon, index, line_seg, expr, degree - 1) + result /= (degree + 1) + return result + + +def integration_reduction(facets, index, a, b, expr, dims, degree): + """Helper method for main_integrate. Returns the value of the input + expression evaluated over the polytope facet referenced by a given index. + + Parameters + =========== + + facets : + List of facets of the polytope. + index : + Index referencing the facet to integrate the expression over. + a : + Hyperplane parameter denoting direction. + b : + Hyperplane parameter denoting distance. + expr : + The expression to integrate over the facet. + dims : + List of symbols denoting axes. + degree : + Degree of the homogeneous polynomial. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.integrals.intpoly import integration_reduction,\ + hyperplane_parameters + >>> from sympy import Point, Polygon + >>> triangle = Polygon(Point(0, 3), Point(5, 3), Point(1, 1)) + >>> facets = triangle.sides + >>> a, b = hyperplane_parameters(triangle)[0] + >>> integration_reduction(facets, 0, a, b, 1, (x, y), 0) + 5 + """ + expr = _sympify(expr) + if expr.is_zero: + return expr + + value = S.Zero + x0 = facets[index].points[0] + m = len(facets) + gens = (x, y) + + inner_product = diff(expr, gens[0]) * x0[0] + diff(expr, gens[1]) * x0[1] + + if inner_product != 0: + value += integration_reduction(facets, index, a, b, + inner_product, dims, degree - 1) + + value += left_integral2D(m, index, facets, x0, expr, gens) + + return value/(len(dims) + degree - 1) + + +def left_integral2D(m, index, facets, x0, expr, gens): + """Computes the left integral of Eq 10 in Chin et al. + For the 2D case, the integral is just an evaluation of the polynomial + at the intersection of two facets which is multiplied by the distance + between the first point of facet and that intersection. + + Parameters + ========== + + m : + No. of hyperplanes. + index : + Index of facet to find intersections with. + facets : + List of facets(Line Segments in 2D case). + x0 : + First point on facet referenced by index. + expr : + Input polynomial + gens : + Generators which generate the polynomial + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.integrals.intpoly import left_integral2D + >>> from sympy import Point, Polygon + >>> triangle = Polygon(Point(0, 3), Point(5, 3), Point(1, 1)) + >>> facets = triangle.sides + >>> left_integral2D(3, 0, facets, facets[0].points[0], 1, (x, y)) + 5 + """ + value = S.Zero + for j in range(m): + intersect = () + if j in ((index - 1) % m, (index + 1) % m): + intersect = intersection(facets[index], facets[j], "segment2D") + if intersect: + distance_origin = norm(tuple(map(lambda x, y: x - y, + intersect, x0))) + if is_vertex(intersect): + if isinstance(expr, Expr): + if len(gens) == 3: + expr_dict = {gens[0]: intersect[0], + gens[1]: intersect[1], + gens[2]: intersect[2]} + else: + expr_dict = {gens[0]: intersect[0], + gens[1]: intersect[1]} + value += distance_origin * expr.subs(expr_dict) + else: + value += distance_origin * expr + return value + + +def integration_reduction_dynamic(facets, index, a, b, expr, degree, dims, + x_index, y_index, max_index, x0, + monomial_values, monom_index, vertices=None, + hp_param=None): + """The same integration_reduction function which uses a dynamic + programming approach to compute terms by using the values of the integral + of previously computed terms. + + Parameters + ========== + + facets : + Facets of the Polytope. + index : + Index of facet to find intersections with.(Used in left_integral()). + a, b : + Hyperplane parameters. + expr : + Input monomial. + degree : + Total degree of ``expr``. + dims : + Tuple denoting axes variables. + x_index : + Exponent of 'x' in ``expr``. + y_index : + Exponent of 'y' in ``expr``. + max_index : + Maximum exponent of any monomial in ``monomial_values``. + x0 : + First point on ``facets[index]``. + monomial_values : + List of monomial values constituting the polynomial. + monom_index : + Index of monomial whose integration is being found. + vertices : optional + Coordinates of vertices constituting the 3-Polytope. + hp_param : optional + Hyperplane Parameter of the face of the facets[index]. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.integrals.intpoly import (integration_reduction_dynamic, \ + hyperplane_parameters) + >>> from sympy import Point, Polygon + >>> triangle = Polygon(Point(0, 3), Point(5, 3), Point(1, 1)) + >>> facets = triangle.sides + >>> a, b = hyperplane_parameters(triangle)[0] + >>> x0 = facets[0].points[0] + >>> monomial_values = [[0, 0, 0, 0], [1, 0, 0, 5],\ + [y, 0, 1, 15], [x, 1, 0, None]] + >>> integration_reduction_dynamic(facets, 0, a, b, x, 1, (x, y), 1, 0, 1,\ + x0, monomial_values, 3) + 25/2 + """ + value = S.Zero + m = len(facets) + + if expr == S.Zero: + return expr + + if len(dims) == 2: + if not expr.is_number: + _, x_degree, y_degree, _ = monomial_values[monom_index] + x_index = monom_index - max_index + \ + x_index - 2 if x_degree > 0 else 0 + y_index = monom_index - 1 if y_degree > 0 else 0 + x_value, y_value =\ + monomial_values[x_index][3], monomial_values[y_index][3] + + value += x_degree * x_value * x0[0] + y_degree * y_value * x0[1] + + value += left_integral2D(m, index, facets, x0, expr, dims) + else: + # For 3D use case the max_index contains the z_degree of the term + z_index = max_index + if not expr.is_number: + x_degree, y_degree, z_degree = y_index,\ + z_index - x_index - y_index, x_index + x_value = monomial_values[z_index - 1][y_index - 1][x_index][7]\ + if x_degree > 0 else 0 + y_value = monomial_values[z_index - 1][y_index][x_index][7]\ + if y_degree > 0 else 0 + z_value = monomial_values[z_index - 1][y_index][x_index - 1][7]\ + if z_degree > 0 else 0 + + value += x_degree * x_value * x0[0] + y_degree * y_value * x0[1] \ + + z_degree * z_value * x0[2] + + value += left_integral3D(facets, index, expr, + vertices, hp_param, degree) + return value / (len(dims) + degree - 1) + + +def left_integral3D(facets, index, expr, vertices, hp_param, degree): + """Computes the left integral of Eq 10 in Chin et al. + + Explanation + =========== + + For the 3D case, this is the sum of the integral values over constituting + line segments of the face (which is accessed by facets[index]) multiplied + by the distance between the first point of facet and that line segment. + + Parameters + ========== + + facets : + List of faces of the 3-Polytope. + index : + Index of face over which integral is to be calculated. + expr : + Input polynomial. + vertices : + List of vertices that constitute the 3-Polytope. + hp_param : + The hyperplane parameters of the face. + degree : + Degree of the ``expr``. + + Examples + ======== + + >>> from sympy.integrals.intpoly import left_integral3D + >>> cube = [[(0, 0, 0), (0, 0, 5), (0, 5, 0), (0, 5, 5), (5, 0, 0),\ + (5, 0, 5), (5, 5, 0), (5, 5, 5)],\ + [2, 6, 7, 3], [3, 7, 5, 1], [7, 6, 4, 5], [1, 5, 4, 0],\ + [3, 1, 0, 2], [0, 4, 6, 2]] + >>> facets = cube[1:] + >>> vertices = cube[0] + >>> left_integral3D(facets, 3, 1, vertices, ([0, -1, 0], -5), 0) + -50 + """ + value = S.Zero + facet = facets[index] + x0 = vertices[facet[0]] + facet_len = len(facet) + for i, fac in enumerate(facet): + side = (vertices[fac], vertices[facet[(i + 1) % facet_len]]) + value += distance_to_side(x0, side, hp_param[0]) * \ + lineseg_integrate(facet, i, side, expr, degree) + return value + + +def gradient_terms(binomial_power=0, no_of_gens=2): + """Returns a list of all the possible monomials between + 0 and y**binomial_power for 2D case and z**binomial_power + for 3D case. + + Parameters + ========== + + binomial_power : + Power upto which terms are generated. + no_of_gens : + Denotes whether terms are being generated for 2D or 3D case. + + Examples + ======== + + >>> from sympy.integrals.intpoly import gradient_terms + >>> gradient_terms(2) + [[1, 0, 0, 0], [y, 0, 1, 0], [y**2, 0, 2, 0], [x, 1, 0, 0], + [x*y, 1, 1, 0], [x**2, 2, 0, 0]] + >>> gradient_terms(2, 3) + [[[[1, 0, 0, 0, 0, 0, 0, 0]]], [[[y, 0, 1, 0, 1, 0, 0, 0], + [z, 0, 0, 1, 1, 0, 1, 0]], [[x, 1, 0, 0, 1, 1, 0, 0]]], + [[[y**2, 0, 2, 0, 2, 0, 0, 0], [y*z, 0, 1, 1, 2, 0, 1, 0], + [z**2, 0, 0, 2, 2, 0, 2, 0]], [[x*y, 1, 1, 0, 2, 1, 0, 0], + [x*z, 1, 0, 1, 2, 1, 1, 0]], [[x**2, 2, 0, 0, 2, 2, 0, 0]]]] + """ + if no_of_gens == 2: + count = 0 + terms = [None] * int((binomial_power ** 2 + 3 * binomial_power + 2) / 2) + for x_count in range(0, binomial_power + 1): + for y_count in range(0, binomial_power - x_count + 1): + terms[count] = [x**x_count*y**y_count, + x_count, y_count, 0] + count += 1 + else: + terms = [[[[x ** x_count * y ** y_count * + z ** (z_count - y_count - x_count), + x_count, y_count, z_count - y_count - x_count, + z_count, x_count, z_count - y_count - x_count, 0] + for y_count in range(z_count - x_count, -1, -1)] + for x_count in range(0, z_count + 1)] + for z_count in range(0, binomial_power + 1)] + return terms + + +def hyperplane_parameters(poly, vertices=None): + """A helper function to return the hyperplane parameters + of which the facets of the polytope are a part of. + + Parameters + ========== + + poly : + The input 2/3-Polytope. + vertices : + Vertex indices of 3-Polytope. + + Examples + ======== + + >>> from sympy import Point, Polygon + >>> from sympy.integrals.intpoly import hyperplane_parameters + >>> hyperplane_parameters(Polygon(Point(0, 3), Point(5, 3), Point(1, 1))) + [((0, 1), 3), ((1, -2), -1), ((-2, -1), -3)] + >>> cube = [[(0, 0, 0), (0, 0, 5), (0, 5, 0), (0, 5, 5), (5, 0, 0),\ + (5, 0, 5), (5, 5, 0), (5, 5, 5)],\ + [2, 6, 7, 3], [3, 7, 5, 1], [7, 6, 4, 5], [1, 5, 4, 0],\ + [3, 1, 0, 2], [0, 4, 6, 2]] + >>> hyperplane_parameters(cube[1:], cube[0]) + [([0, -1, 0], -5), ([0, 0, -1], -5), ([-1, 0, 0], -5), + ([0, 1, 0], 0), ([1, 0, 0], 0), ([0, 0, 1], 0)] + """ + if isinstance(poly, Polygon): + vertices = list(poly.vertices) + [poly.vertices[0]] # Close the polygon + params = [None] * (len(vertices) - 1) + + for i in range(len(vertices) - 1): + v1 = vertices[i] + v2 = vertices[i + 1] + + a1 = v1[1] - v2[1] + a2 = v2[0] - v1[0] + b = v2[0] * v1[1] - v2[1] * v1[0] + + factor = gcd_list([a1, a2, b]) + + b = S(b) / factor + a = (S(a1) / factor, S(a2) / factor) + params[i] = (a, b) + else: + params = [None] * len(poly) + for i, polygon in enumerate(poly): + v1, v2, v3 = [vertices[vertex] for vertex in polygon[:3]] + normal = cross_product(v1, v2, v3) + b = sum(normal[j] * v1[j] for j in range(0, 3)) + fac = gcd_list(normal) + if fac.is_zero: + fac = 1 + normal = [j / fac for j in normal] + b = b / fac + params[i] = (normal, b) + return params + + +def cross_product(v1, v2, v3): + """Returns the cross-product of vectors (v2 - v1) and (v3 - v1) + That is : (v2 - v1) X (v3 - v1) + """ + v2 = [v2[j] - v1[j] for j in range(0, 3)] + v3 = [v3[j] - v1[j] for j in range(0, 3)] + return [v3[2] * v2[1] - v3[1] * v2[2], + v3[0] * v2[2] - v3[2] * v2[0], + v3[1] * v2[0] - v3[0] * v2[1]] + + +def best_origin(a, b, lineseg, expr): + """Helper method for polytope_integrate. Currently not used in the main + algorithm. + + Explanation + =========== + + Returns a point on the lineseg whose vector inner product with the + divergence of `expr` yields an expression with the least maximum + total power. + + Parameters + ========== + + a : + Hyperplane parameter denoting direction. + b : + Hyperplane parameter denoting distance. + lineseg : + Line segment on which to find the origin. + expr : + The expression which determines the best point. + + Algorithm(currently works only for 2D use case) + =============================================== + + 1 > Firstly, check for edge cases. Here that would refer to vertical + or horizontal lines. + + 2 > If input expression is a polynomial containing more than one generator + then find out the total power of each of the generators. + + x**2 + 3 + x*y + x**4*y**5 ---> {x: 7, y: 6} + + If expression is a constant value then pick the first boundary point + of the line segment. + + 3 > First check if a point exists on the line segment where the value of + the highest power generator becomes 0. If not check if the value of + the next highest becomes 0. If none becomes 0 within line segment + constraints then pick the first boundary point of the line segment. + Actually, any point lying on the segment can be picked as best origin + in the last case. + + Examples + ======== + + >>> from sympy.integrals.intpoly import best_origin + >>> from sympy.abc import x, y + >>> from sympy import Point, Segment2D + >>> l = Segment2D(Point(0, 3), Point(1, 1)) + >>> expr = x**3*y**7 + >>> best_origin((2, 1), 3, l, expr) + (0, 3.0) + """ + a1, b1 = lineseg.points[0] + + def x_axis_cut(ls): + """Returns the point where the input line segment + intersects the x-axis. + + Parameters + ========== + + ls : + Line segment + """ + p, q = ls.points + if p.y.is_zero: + return tuple(p) + elif q.y.is_zero: + return tuple(q) + elif p.y/q.y < S.Zero: + return p.y * (p.x - q.x)/(q.y - p.y) + p.x, S.Zero + else: + return () + + def y_axis_cut(ls): + """Returns the point where the input line segment + intersects the y-axis. + + Parameters + ========== + + ls : + Line segment + """ + p, q = ls.points + if p.x.is_zero: + return tuple(p) + elif q.x.is_zero: + return tuple(q) + elif p.x/q.x < S.Zero: + return S.Zero, p.x * (p.y - q.y)/(q.x - p.x) + p.y + else: + return () + + gens = (x, y) + power_gens = {} + + for i in gens: + power_gens[i] = S.Zero + + if len(gens) > 1: + # Special case for vertical and horizontal lines + if len(gens) == 2: + if a[0] == 0: + if y_axis_cut(lineseg): + return S.Zero, b/a[1] + else: + return a1, b1 + elif a[1] == 0: + if x_axis_cut(lineseg): + return b/a[0], S.Zero + else: + return a1, b1 + + if isinstance(expr, Expr): # Find the sum total of power of each + if expr.is_Add: # generator and store in a dictionary. + for monomial in expr.args: + if monomial.is_Pow: + if monomial.args[0] in gens: + power_gens[monomial.args[0]] += monomial.args[1] + else: + for univariate in monomial.args: + term_type = len(univariate.args) + if term_type == 0 and univariate in gens: + power_gens[univariate] += 1 + elif term_type == 2 and univariate.args[0] in gens: + power_gens[univariate.args[0]] +=\ + univariate.args[1] + elif expr.is_Mul: + for term in expr.args: + term_type = len(term.args) + if term_type == 0 and term in gens: + power_gens[term] += 1 + elif term_type == 2 and term.args[0] in gens: + power_gens[term.args[0]] += term.args[1] + elif expr.is_Pow: + power_gens[expr.args[0]] = expr.args[1] + elif expr.is_Symbol: + power_gens[expr] += 1 + else: # If `expr` is a constant take first vertex of the line segment. + return a1, b1 + + # TODO : This part is quite hacky. Should be made more robust with + # TODO : respect to symbol names and scalable w.r.t higher dimensions. + power_gens = sorted(power_gens.items(), key=lambda k: str(k[0])) + if power_gens[0][1] >= power_gens[1][1]: + if y_axis_cut(lineseg): + x0 = (S.Zero, b / a[1]) + elif x_axis_cut(lineseg): + x0 = (b / a[0], S.Zero) + else: + x0 = (a1, b1) + else: + if x_axis_cut(lineseg): + x0 = (b/a[0], S.Zero) + elif y_axis_cut(lineseg): + x0 = (S.Zero, b/a[1]) + else: + x0 = (a1, b1) + else: + x0 = (b/a[0]) + return x0 + + +def decompose(expr, separate=False): + """Decomposes an input polynomial into homogeneous ones of + smaller or equal degree. + + Explanation + =========== + + Returns a dictionary with keys as the degree of the smaller + constituting polynomials. Values are the constituting polynomials. + + Parameters + ========== + + expr : Expr + Polynomial(SymPy expression). + separate : bool + If True then simply return a list of the constituent monomials + If not then break up the polynomial into constituent homogeneous + polynomials. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.integrals.intpoly import decompose + >>> decompose(x**2 + x*y + x + y + x**3*y**2 + y**5) + {1: x + y, 2: x**2 + x*y, 5: x**3*y**2 + y**5} + >>> decompose(x**2 + x*y + x + y + x**3*y**2 + y**5, True) + {x, x**2, y, y**5, x*y, x**3*y**2} + """ + poly_dict = {} + + if isinstance(expr, Expr) and not expr.is_number: + if expr.is_Symbol: + poly_dict[1] = expr + elif expr.is_Add: + symbols = expr.atoms(Symbol) + degrees = [(sum(degree_list(monom, *symbols)), monom) + for monom in expr.args] + if separate: + return {monom[1] for monom in degrees} + else: + for monom in degrees: + degree, term = monom + if poly_dict.get(degree): + poly_dict[degree] += term + else: + poly_dict[degree] = term + elif expr.is_Pow: + _, degree = expr.args + poly_dict[degree] = expr + else: # Now expr can only be of `Mul` type + degree = 0 + for term in expr.args: + term_type = len(term.args) + if term_type == 0 and term.is_Symbol: + degree += 1 + elif term_type == 2: + degree += term.args[1] + poly_dict[degree] = expr + else: + poly_dict[0] = expr + + if separate: + return set(poly_dict.values()) + return poly_dict + + +def point_sort(poly, normal=None, clockwise=True): + """Returns the same polygon with points sorted in clockwise or + anti-clockwise order. + + Note that it's necessary for input points to be sorted in some order + (clockwise or anti-clockwise) for the integration algorithm to work. + As a convention algorithm has been implemented keeping clockwise + orientation in mind. + + Parameters + ========== + + poly: + 2D or 3D Polygon. + normal : optional + The normal of the plane which the 3-Polytope is a part of. + clockwise : bool, optional + Returns points sorted in clockwise order if True and + anti-clockwise if False. + + Examples + ======== + + >>> from sympy.integrals.intpoly import point_sort + >>> from sympy import Point + >>> point_sort([Point(0, 0), Point(1, 0), Point(1, 1)]) + [Point2D(1, 1), Point2D(1, 0), Point2D(0, 0)] + """ + pts = poly.vertices if isinstance(poly, Polygon) else poly + n = len(pts) + if n < 2: + return list(pts) + + order = S.One if clockwise else S.NegativeOne + dim = len(pts[0]) + if dim == 2: + center = Point(sum((vertex.x for vertex in pts)) / n, + sum((vertex.y for vertex in pts)) / n) + else: + center = Point(sum((vertex.x for vertex in pts)) / n, + sum((vertex.y for vertex in pts)) / n, + sum((vertex.z for vertex in pts)) / n) + + def compare(a, b): + if a.x - center.x >= S.Zero and b.x - center.x < S.Zero: + return -order + elif a.x - center.x < 0 and b.x - center.x >= 0: + return order + elif a.x - center.x == 0 and b.x - center.x == 0: + if a.y - center.y >= 0 or b.y - center.y >= 0: + return -order if a.y > b.y else order + return -order if b.y > a.y else order + + det = (a.x - center.x) * (b.y - center.y) -\ + (b.x - center.x) * (a.y - center.y) + if det < 0: + return -order + elif det > 0: + return order + + first = (a.x - center.x) * (a.x - center.x) +\ + (a.y - center.y) * (a.y - center.y) + second = (b.x - center.x) * (b.x - center.x) +\ + (b.y - center.y) * (b.y - center.y) + return -order if first > second else order + + def compare3d(a, b): + det = cross_product(center, a, b) + dot_product = sum(det[i] * normal[i] for i in range(0, 3)) + if dot_product < 0: + return -order + elif dot_product > 0: + return order + + return sorted(pts, key=cmp_to_key(compare if dim==2 else compare3d)) + + +def norm(point): + """Returns the Euclidean norm of a point from origin. + + Parameters + ========== + + point: + This denotes a point in the dimension_al spac_e. + + Examples + ======== + + >>> from sympy.integrals.intpoly import norm + >>> from sympy import Point + >>> norm(Point(2, 7)) + sqrt(53) + """ + half = S.Half + if isinstance(point, (list, tuple)): + return sum(coord ** 2 for coord in point) ** half + elif isinstance(point, Point): + if isinstance(point, Point2D): + return (point.x ** 2 + point.y ** 2) ** half + else: + return (point.x ** 2 + point.y ** 2 + point.z) ** half + elif isinstance(point, dict): + return sum(i**2 for i in point.values()) ** half + + +def intersection(geom_1, geom_2, intersection_type): + """Returns intersection between geometric objects. + + Explanation + =========== + + Note that this function is meant for use in integration_reduction and + at that point in the calling function the lines denoted by the segments + surely intersect within segment boundaries. Coincident lines are taken + to be non-intersecting. Also, the hyperplane intersection for 2D case is + also implemented. + + Parameters + ========== + + geom_1, geom_2: + The input line segments. + + Examples + ======== + + >>> from sympy.integrals.intpoly import intersection + >>> from sympy import Point, Segment2D + >>> l1 = Segment2D(Point(1, 1), Point(3, 5)) + >>> l2 = Segment2D(Point(2, 0), Point(2, 5)) + >>> intersection(l1, l2, "segment2D") + (2, 3) + >>> p1 = ((-1, 0), 0) + >>> p2 = ((0, 1), 1) + >>> intersection(p1, p2, "plane2D") + (0, 1) + """ + if intersection_type[:-2] == "segment": + if intersection_type == "segment2D": + x1, y1 = geom_1.points[0] + x2, y2 = geom_1.points[1] + x3, y3 = geom_2.points[0] + x4, y4 = geom_2.points[1] + elif intersection_type == "segment3D": + x1, y1, z1 = geom_1.points[0] + x2, y2, z2 = geom_1.points[1] + x3, y3, z3 = geom_2.points[0] + x4, y4, z4 = geom_2.points[1] + + denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4) + if denom: + t1 = x1 * y2 - y1 * x2 + t2 = x3 * y4 - x4 * y3 + return (S(t1 * (x3 - x4) - t2 * (x1 - x2)) / denom, + S(t1 * (y3 - y4) - t2 * (y1 - y2)) / denom) + if intersection_type[:-2] == "plane": + if intersection_type == "plane2D": # Intersection of hyperplanes + a1x, a1y = geom_1[0] + a2x, a2y = geom_2[0] + b1, b2 = geom_1[1], geom_2[1] + + denom = a1x * a2y - a2x * a1y + if denom: + return (S(b1 * a2y - b2 * a1y) / denom, + S(b2 * a1x - b1 * a2x) / denom) + + +def is_vertex(ent): + """If the input entity is a vertex return True. + + Parameter + ========= + + ent : + Denotes a geometric entity representing a point. + + Examples + ======== + + >>> from sympy import Point + >>> from sympy.integrals.intpoly import is_vertex + >>> is_vertex((2, 3)) + True + >>> is_vertex((2, 3, 6)) + True + >>> is_vertex(Point(2, 3)) + True + """ + if isinstance(ent, tuple): + if len(ent) in [2, 3]: + return True + elif isinstance(ent, Point): + return True + return False + + +def plot_polytope(poly): + """Plots the 2D polytope using the functions written in plotting + module which in turn uses matplotlib backend. + + Parameter + ========= + + poly: + Denotes a 2-Polytope. + """ + from sympy.plotting.plot import Plot, List2DSeries + + xl = [vertex.x for vertex in poly.vertices] + yl = [vertex.y for vertex in poly.vertices] + + xl.append(poly.vertices[0].x) # Closing the polygon + yl.append(poly.vertices[0].y) + + l2ds = List2DSeries(xl, yl) + p = Plot(l2ds, axes='label_axes=True') + p.show() + + +def plot_polynomial(expr): + """Plots the polynomial using the functions written in + plotting module which in turn uses matplotlib backend. + + Parameter + ========= + + expr: + Denotes a polynomial(SymPy expression). + """ + from sympy.plotting.plot import plot3d, plot + gens = expr.free_symbols + if len(gens) == 2: + plot3d(expr) + else: + plot(expr) diff --git a/.venv/lib/python3.13/site-packages/sympy/integrals/laplace.py b/.venv/lib/python3.13/site-packages/sympy/integrals/laplace.py new file mode 100644 index 0000000000000000000000000000000000000000..604c4a2711440f13c62f0d778e382e4daaf33148 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/integrals/laplace.py @@ -0,0 +1,2377 @@ +"""Laplace Transforms""" +import sys +import sympy +from sympy.core import S, pi, I +from sympy.core.add import Add +from sympy.core.cache import cacheit +from sympy.core.expr import Expr +from sympy.core.function import ( + AppliedUndef, Derivative, expand, expand_complex, expand_mul, expand_trig, + Lambda, WildFunction, diff, Subs) +from sympy.core.mul import Mul, prod +from sympy.core.relational import ( + _canonical, Ge, Gt, Lt, Unequality, Eq, Ne, Relational) +from sympy.core.sorting import ordered +from sympy.core.symbol import Dummy, symbols, Wild +from sympy.functions.elementary.complexes import ( + re, im, arg, Abs, polar_lift, periodic_argument) +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.hyperbolic import cosh, coth, sinh, asinh +from sympy.functions.elementary.miscellaneous import Max, Min, sqrt +from sympy.functions.elementary.piecewise import ( + Piecewise, piecewise_exclusive) +from sympy.functions.elementary.trigonometric import cos, sin, atan, sinc +from sympy.functions.special.bessel import besseli, besselj, besselk, bessely +from sympy.functions.special.delta_functions import DiracDelta, Heaviside +from sympy.functions.special.error_functions import erf, erfc, Ei +from sympy.functions.special.gamma_functions import ( + digamma, gamma, lowergamma, uppergamma) +from sympy.functions.special.singularity_functions import SingularityFunction +from sympy.integrals import integrate, Integral +from sympy.integrals.transforms import ( + _simplify, IntegralTransform, IntegralTransformError) +from sympy.logic.boolalg import to_cnf, conjuncts, disjuncts, Or, And +from sympy.matrices.matrixbase import MatrixBase +from sympy.polys.matrices.linsolve import _lin_eq2dict +from sympy.polys.polyerrors import PolynomialError +from sympy.polys.polyroots import roots +from sympy.polys.polytools import Poly +from sympy.polys.rationaltools import together +from sympy.polys.rootoftools import RootSum +from sympy.utilities.exceptions import ( + sympy_deprecation_warning, SymPyDeprecationWarning, ignore_warnings) +from sympy.utilities.misc import debugf + +_LT_level = 0 + + +def DEBUG_WRAP(func): + def wrap(*args, **kwargs): + from sympy import SYMPY_DEBUG + global _LT_level + + if not SYMPY_DEBUG: + return func(*args, **kwargs) + + if _LT_level == 0: + print('\n' + '-'*78, file=sys.stderr) + print('-LT- %s%s%s' % (' '*_LT_level, func.__name__, args), + file=sys.stderr) + _LT_level += 1 + if ( + func.__name__ == '_laplace_transform_integration' or + func.__name__ == '_inverse_laplace_transform_integration'): + sympy.SYMPY_DEBUG = False + print('**** %sIntegrating ...' % (' '*_LT_level), file=sys.stderr) + result = func(*args, **kwargs) + sympy.SYMPY_DEBUG = True + else: + result = func(*args, **kwargs) + _LT_level -= 1 + print('-LT- %s---> %s' % (' '*_LT_level, result), file=sys.stderr) + if _LT_level == 0: + print('-'*78 + '\n', file=sys.stderr) + return result + return wrap + + +def _debug(text): + from sympy import SYMPY_DEBUG + + if SYMPY_DEBUG: + print('-LT- %s%s' % (' '*_LT_level, text), file=sys.stderr) + + +def _simplifyconds(expr, s, a): + r""" + Naively simplify some conditions occurring in ``expr``, + given that `\operatorname{Re}(s) > a`. + + Examples + ======== + + >>> from sympy.integrals.laplace import _simplifyconds + >>> from sympy.abc import x + >>> from sympy import sympify as S + >>> _simplifyconds(abs(x**2) < 1, x, 1) + False + >>> _simplifyconds(abs(x**2) < 1, x, 2) + False + >>> _simplifyconds(abs(x**2) < 1, x, 0) + Abs(x**2) < 1 + >>> _simplifyconds(abs(1/x**2) < 1, x, 1) + True + >>> _simplifyconds(S(1) < abs(x), x, 1) + True + >>> _simplifyconds(S(1) < abs(1/x), x, 1) + False + + >>> from sympy import Ne + >>> _simplifyconds(Ne(1, x**3), x, 1) + True + >>> _simplifyconds(Ne(1, x**3), x, 2) + True + >>> _simplifyconds(Ne(1, x**3), x, 0) + Ne(1, x**3) + """ + + def power(ex): + if ex == s: + return 1 + if ex.is_Pow and ex.base == s: + return ex.exp + return None + + def bigger(ex1, ex2): + """ Return True only if |ex1| > |ex2|, False only if |ex1| < |ex2|. + Else return None. """ + if ex1.has(s) and ex2.has(s): + return None + if isinstance(ex1, Abs): + ex1 = ex1.args[0] + if isinstance(ex2, Abs): + ex2 = ex2.args[0] + if ex1.has(s): + return bigger(1/ex2, 1/ex1) + n = power(ex2) + if n is None: + return None + try: + if n > 0 and (Abs(ex1) <= Abs(a)**n) == S.true: + return False + if n < 0 and (Abs(ex1) >= Abs(a)**n) == S.true: + return True + except TypeError: + return None + + def replie(x, y): + """ simplify x < y """ + if (not (x.is_positive or isinstance(x, Abs)) + or not (y.is_positive or isinstance(y, Abs))): + return (x < y) + r = bigger(x, y) + if r is not None: + return not r + return (x < y) + + def replue(x, y): + b = bigger(x, y) + if b in (True, False): + return True + return Unequality(x, y) + + def repl(ex, *args): + if ex in (True, False): + return bool(ex) + return ex.replace(*args) + + from sympy.simplify.radsimp import collect_abs + expr = collect_abs(expr) + expr = repl(expr, Lt, replie) + expr = repl(expr, Gt, lambda x, y: replie(y, x)) + expr = repl(expr, Unequality, replue) + return S(expr) + + +@DEBUG_WRAP +def expand_dirac_delta(expr): + """ + Expand an expression involving DiractDelta to get it as a linear + combination of DiracDelta functions. + """ + return _lin_eq2dict(expr, expr.atoms(DiracDelta)) + + +@DEBUG_WRAP +def _laplace_transform_integration(f, t, s_, *, simplify): + """ The backend function for doing Laplace transforms by integration. + + This backend assumes that the frontend has already split sums + such that `f` is to an addition anymore. + """ + s = Dummy('s') + + if f.has(DiracDelta): + return None + + F = integrate(f*exp(-s*t), (t, S.Zero, S.Infinity)) + + if not F.has(Integral): + return _simplify(F.subs(s, s_), simplify), S.NegativeInfinity, S.true + + if not F.is_Piecewise: + return None + + F, cond = F.args[0] + if F.has(Integral): + return None + + def process_conds(conds): + """ Turn ``conds`` into a strip and auxiliary conditions. """ + from sympy.solvers.inequalities import _solve_inequality + a = S.NegativeInfinity + aux = S.true + conds = conjuncts(to_cnf(conds)) + p, q, w1, w2, w3, w4, w5 = symbols( + 'p q w1 w2 w3 w4 w5', cls=Wild, exclude=[s]) + patterns = ( + p*Abs(arg((s + w3)*q)) < w2, + p*Abs(arg((s + w3)*q)) <= w2, + Abs(periodic_argument((s + w3)**p*q, w1)) < w2, + Abs(periodic_argument((s + w3)**p*q, w1)) <= w2, + Abs(periodic_argument((polar_lift(s + w3))**p*q, w1)) < w2, + Abs(periodic_argument((polar_lift(s + w3))**p*q, w1)) <= w2) + for c in conds: + a_ = S.Infinity + aux_ = [] + for d in disjuncts(c): + if d.is_Relational and s in d.rhs.free_symbols: + d = d.reversed + if d.is_Relational and isinstance(d, (Ge, Gt)): + d = d.reversedsign + for pat in patterns: + m = d.match(pat) + if m: + break + if m and m[q].is_positive and m[w2]/m[p] == pi/2: + d = -re(s + m[w3]) < 0 + m = d.match(p - cos(w1*Abs(arg(s*w5))*w2)*Abs(s**w3)**w4 < 0) + if not m: + m = d.match( + cos(p - Abs(periodic_argument(s**w1*w5, q))*w2) * + Abs(s**w3)**w4 < 0) + if not m: + m = d.match( + p - cos( + Abs(periodic_argument(polar_lift(s)**w1*w5, q))*w2 + )*Abs(s**w3)**w4 < 0) + if m and all(m[wild].is_positive for wild in [ + w1, w2, w3, w4, w5]): + d = re(s) > m[p] + d_ = d.replace( + re, lambda x: x.expand().as_real_imag()[0]).subs(re(s), t) + if ( + not d.is_Relational or d.rel_op in ('==', '!=') + or d_.has(s) or not d_.has(t)): + aux_ += [d] + continue + soln = _solve_inequality(d_, t) + if not soln.is_Relational or soln.rel_op in ('==', '!='): + aux_ += [d] + continue + if soln.lts == t: + return None + else: + a_ = Min(soln.lts, a_) + if a_ is not S.Infinity: + a = Max(a_, a) + else: + aux = And(aux, Or(*aux_)) + return a, aux.canonical if aux.is_Relational else aux + + conds = [process_conds(c) for c in disjuncts(cond)] + conds2 = [x for x in conds if x[1] != + S.false and x[0] is not S.NegativeInfinity] + if not conds2: + conds2 = [x for x in conds if x[1] != S.false] + conds = list(ordered(conds2)) + + def cnt(expr): + if expr in (True, False): + return 0 + return expr.count_ops() + conds.sort(key=lambda x: (-x[0], cnt(x[1]))) + + if not conds: + return None + a, aux = conds[0] # XXX is [0] always the right one? + + def sbs(expr): + return expr.subs(s, s_) + if simplify: + F = _simplifyconds(F, s, a) + aux = _simplifyconds(aux, s, a) + return _simplify(F.subs(s, s_), simplify), sbs(a), _canonical(sbs(aux)) + + +@DEBUG_WRAP +def _laplace_deep_collect(f, t): + """ + This is an internal helper function that traverses through the expression + tree of `f(t)` and collects arguments. The purpose of it is that + anything like `f(w*t-1*t-c)` will be written as `f((w-1)*t-c)` such that + it can match `f(a*t+b)`. + """ + if not isinstance(f, Expr): + return f + if (p := f.as_poly(t)) is not None: + return p.as_expr() + func = f.func + args = [_laplace_deep_collect(arg, t) for arg in f.args] + return func(*args) + + +@cacheit +def _laplace_build_rules(): + """ + This is an internal helper function that returns the table of Laplace + transform rules in terms of the time variable `t` and the frequency + variable `s`. It is used by ``_laplace_apply_rules``. Each entry is a + tuple containing: + + (time domain pattern, + frequency-domain replacement, + condition for the rule to be applied, + convergence plane, + preparation function) + + The preparation function is a function with one argument that is applied + to the expression before matching. For most rules it should be + ``_laplace_deep_collect``. + """ + t = Dummy('t') + s = Dummy('s') + a = Wild('a', exclude=[t]) + b = Wild('b', exclude=[t]) + n = Wild('n', exclude=[t]) + tau = Wild('tau', exclude=[t]) + omega = Wild('omega', exclude=[t]) + def dco(f): return _laplace_deep_collect(f, t) + _debug('_laplace_build_rules is building rules') + + laplace_transform_rules = [ + (a, a/s, + S.true, S.Zero, dco), # 4.2.1 + (DiracDelta(a*t-b), exp(-s*b/a)/Abs(a), + Or(And(a > 0, b >= 0), And(a < 0, b <= 0)), + S.NegativeInfinity, dco), # Not in Bateman54 + (DiracDelta(a*t-b), S(0), + Or(And(a < 0, b >= 0), And(a > 0, b <= 0)), + S.NegativeInfinity, dco), # Not in Bateman54 + (Heaviside(a*t-b), exp(-s*b/a)/s, + And(a > 0, b > 0), S.Zero, dco), # 4.4.1 + (Heaviside(a*t-b), (1-exp(-s*b/a))/s, + And(a < 0, b < 0), S.Zero, dco), # 4.4.1 + (Heaviside(a*t-b), 1/s, + And(a > 0, b <= 0), S.Zero, dco), # 4.4.1 + (Heaviside(a*t-b), 0, + And(a < 0, b > 0), S.Zero, dco), # 4.4.1 + (t, 1/s**2, + S.true, S.Zero, dco), # 4.2.3 + (1/(a*t+b), -exp(-b/a*s)*Ei(-b/a*s)/a, + Abs(arg(b/a)) < pi, S.Zero, dco), # 4.2.6 + (1/sqrt(a*t+b), sqrt(a*pi/s)*exp(b/a*s)*erfc(sqrt(b/a*s))/a, + Abs(arg(b/a)) < pi, S.Zero, dco), # 4.2.18 + ((a*t+b)**(-S(3)/2), + 2*b**(-S(1)/2)-2*(pi*s/a)**(S(1)/2)*exp(b/a*s) * erfc(sqrt(b/a*s))/a, + Abs(arg(b/a)) < pi, S.Zero, dco), # 4.2.20 + (sqrt(t)/(t+b), sqrt(pi/s)-pi*sqrt(b)*exp(b*s)*erfc(sqrt(b*s)), + Abs(arg(b)) < pi, S.Zero, dco), # 4.2.22 + (1/(a*sqrt(t) + t**(3/2)), pi*a**(S(1)/2)*exp(a*s)*erfc(sqrt(a*s)), + S.true, S.Zero, dco), # Not in Bateman54 + (t**n, gamma(n+1)/s**(n+1), + n > -1, S.Zero, dco), # 4.3.1 + ((a*t+b)**n, uppergamma(n+1, b/a*s)*exp(-b/a*s)/s**(n+1)/a, + And(n > -1, Abs(arg(b/a)) < pi), S.Zero, dco), # 4.3.4 + (t**n/(t+a), a**n*gamma(n+1)*uppergamma(-n, a*s), + And(n > -1, Abs(arg(a)) < pi), S.Zero, dco), # 4.3.7 + (exp(a*t-tau), exp(-tau)/(s-a), + S.true, re(a), dco), # 4.5.1 + (t*exp(a*t-tau), exp(-tau)/(s-a)**2, + S.true, re(a), dco), # 4.5.2 + (t**n*exp(a*t), gamma(n+1)/(s-a)**(n+1), + re(n) > -1, re(a), dco), # 4.5.3 + (exp(-a*t**2), sqrt(pi/4/a)*exp(s**2/4/a)*erfc(s/sqrt(4*a)), + re(a) > 0, S.Zero, dco), # 4.5.21 + (t*exp(-a*t**2), + 1/(2*a)-2/sqrt(pi)/(4*a)**(S(3)/2)*s*erfc(s/sqrt(4*a)), + re(a) > 0, S.Zero, dco), # 4.5.22 + (exp(-a/t), 2*sqrt(a/s)*besselk(1, 2*sqrt(a*s)), + re(a) >= 0, S.Zero, dco), # 4.5.25 + (sqrt(t)*exp(-a/t), + S(1)/2*sqrt(pi/s**3)*(1+2*sqrt(a*s))*exp(-2*sqrt(a*s)), + re(a) >= 0, S.Zero, dco), # 4.5.26 + (exp(-a/t)/sqrt(t), sqrt(pi/s)*exp(-2*sqrt(a*s)), + re(a) >= 0, S.Zero, dco), # 4.5.27 + (exp(-a/t)/(t*sqrt(t)), sqrt(pi/a)*exp(-2*sqrt(a*s)), + re(a) > 0, S.Zero, dco), # 4.5.28 + (t**n*exp(-a/t), 2*(a/s)**((n+1)/2)*besselk(n+1, 2*sqrt(a*s)), + re(a) > 0, S.Zero, dco), # 4.5.29 + # TODO: rules with sqrt(a*t) and sqrt(a/t) have stopped working after + # changes to as_base_exp + # (exp(-2*sqrt(a*t)), + # s**(-1)-sqrt(pi*a)*s**(-S(3)/2)*exp(a/s) * erfc(sqrt(a/s)), + # Abs(arg(a)) < pi, S.Zero, dco), # 4.5.31 + # (exp(-2*sqrt(a*t))/sqrt(t), (pi/s)**(S(1)/2)*exp(a/s)*erfc(sqrt(a/s)), + # Abs(arg(a)) < pi, S.Zero, dco), # 4.5.33 + (exp(-a*exp(-t)), a**(-s)*lowergamma(s, a), + S.true, S.Zero, dco), # 4.5.36 + (exp(-a*exp(t)), a**s*uppergamma(-s, a), + re(a) > 0, S.Zero, dco), # 4.5.37 + (log(a*t), -log(exp(S.EulerGamma)*s/a)/s, + a > 0, S.Zero, dco), # 4.6.1 + (log(1+a*t), -exp(s/a)/s*Ei(-s/a), + Abs(arg(a)) < pi, S.Zero, dco), # 4.6.4 + (log(a*t+b), (log(b)-exp(s/b/a)/s*a*Ei(-s/b))/s*a, + And(a > 0, Abs(arg(b)) < pi), S.Zero, dco), # 4.6.5 + (log(t)/sqrt(t), -sqrt(pi/s)*log(4*s*exp(S.EulerGamma)), + S.true, S.Zero, dco), # 4.6.9 + (t**n*log(t), gamma(n+1)*s**(-n-1)*(digamma(n+1)-log(s)), + re(n) > -1, S.Zero, dco), # 4.6.11 + (log(a*t)**2, (log(exp(S.EulerGamma)*s/a)**2+pi**2/6)/s, + a > 0, S.Zero, dco), # 4.6.13 + (sin(omega*t), omega/(s**2+omega**2), + S.true, Abs(im(omega)), dco), # 4,7,1 + (Abs(sin(omega*t)), omega/(s**2+omega**2)*coth(pi*s/2/omega), + omega > 0, S.Zero, dco), # 4.7.2 + (sin(omega*t)/t, atan(omega/s), + S.true, Abs(im(omega)), dco), # 4.7.16 + (sin(omega*t)**2/t, log(1+4*omega**2/s**2)/4, + S.true, 2*Abs(im(omega)), dco), # 4.7.17 + (sin(omega*t)**2/t**2, + omega*atan(2*omega/s)-s*log(1+4*omega**2/s**2)/4, + S.true, 2*Abs(im(omega)), dco), # 4.7.20 + # (sin(2*sqrt(a*t)), sqrt(pi*a)/s/sqrt(s)*exp(-a/s), + # S.true, S.Zero, dco), # 4.7.32 + # (sin(2*sqrt(a*t))/t, pi*erf(sqrt(a/s)), + # S.true, S.Zero, dco), # 4.7.34 + (cos(omega*t), s/(s**2+omega**2), + S.true, Abs(im(omega)), dco), # 4.7.43 + (cos(omega*t)**2, (s**2+2*omega**2)/(s**2+4*omega**2)/s, + S.true, 2*Abs(im(omega)), dco), # 4.7.45 + # (sqrt(t)*cos(2*sqrt(a*t)), sqrt(pi)/2*s**(-S(5)/2)*(s-2*a)*exp(-a/s), + # S.true, S.Zero, dco), # 4.7.66 + # (cos(2*sqrt(a*t))/sqrt(t), sqrt(pi/s)*exp(-a/s), + # S.true, S.Zero, dco), # 4.7.67 + (sin(a*t)*sin(b*t), 2*a*b*s/(s**2+(a+b)**2)/(s**2+(a-b)**2), + S.true, Abs(im(a))+Abs(im(b)), dco), # 4.7.78 + (cos(a*t)*sin(b*t), b*(s**2-a**2+b**2)/(s**2+(a+b)**2)/(s**2+(a-b)**2), + S.true, Abs(im(a))+Abs(im(b)), dco), # 4.7.79 + (cos(a*t)*cos(b*t), s*(s**2+a**2+b**2)/(s**2+(a+b)**2)/(s**2+(a-b)**2), + S.true, Abs(im(a))+Abs(im(b)), dco), # 4.7.80 + (sinh(a*t), a/(s**2-a**2), + S.true, Abs(re(a)), dco), # 4.9.1 + (cosh(a*t), s/(s**2-a**2), + S.true, Abs(re(a)), dco), # 4.9.2 + (sinh(a*t)**2, 2*a**2/(s**3-4*a**2*s), + S.true, 2*Abs(re(a)), dco), # 4.9.3 + (cosh(a*t)**2, (s**2-2*a**2)/(s**3-4*a**2*s), + S.true, 2*Abs(re(a)), dco), # 4.9.4 + (sinh(a*t)/t, log((s+a)/(s-a))/2, + S.true, Abs(re(a)), dco), # 4.9.12 + (t**n*sinh(a*t), gamma(n+1)/2*((s-a)**(-n-1)-(s+a)**(-n-1)), + n > -2, Abs(a), dco), # 4.9.18 + (t**n*cosh(a*t), gamma(n+1)/2*((s-a)**(-n-1)+(s+a)**(-n-1)), + n > -1, Abs(a), dco), # 4.9.19 + # TODO + # (sinh(2*sqrt(a*t)), sqrt(pi*a)/s/sqrt(s)*exp(a/s), + # S.true, S.Zero, dco), # 4.9.34 + # (cosh(2*sqrt(a*t)), 1/s+sqrt(pi*a)/s/sqrt(s)*exp(a/s)*erf(sqrt(a/s)), + # S.true, S.Zero, dco), # 4.9.35 + # ( + # sqrt(t)*sinh(2*sqrt(a*t)), + # pi**(S(1)/2)*s**(-S(5)/2)*(s/2+a) * + # exp(a/s)*erf(sqrt(a/s))-a**(S(1)/2)*s**(-2), + # S.true, S.Zero, dco), # 4.9.36 + # (sqrt(t)*cosh(2*sqrt(a*t)), pi**(S(1)/2)*s**(-S(5)/2)*(s/2+a)*exp(a/s), + # S.true, S.Zero, dco), # 4.9.37 + # (sinh(2*sqrt(a*t))/sqrt(t), + # pi**(S(1)/2)*s**(-S(1)/2)*exp(a/s) * erf(sqrt(a/s)), + # S.true, S.Zero, dco), # 4.9.38 + # (cosh(2*sqrt(a*t))/sqrt(t), pi**(S(1)/2)*s**(-S(1)/2)*exp(a/s), + # S.true, S.Zero, dco), # 4.9.39 + # (sinh(sqrt(a*t))**2/sqrt(t), pi**(S(1)/2)/2*s**(-S(1)/2)*(exp(a/s)-1), + # S.true, S.Zero, dco), # 4.9.40 + # (cosh(sqrt(a*t))**2/sqrt(t), pi**(S(1)/2)/2*s**(-S(1)/2)*(exp(a/s)+1), + # S.true, S.Zero, dco), # 4.9.41 + (erf(a*t), exp(s**2/(2*a)**2)*erfc(s/(2*a))/s, + 4*Abs(arg(a)) < pi, S.Zero, dco), # 4.12.2 + # (erf(sqrt(a*t)), sqrt(a)/sqrt(s+a)/s, + # S.true, Max(S.Zero, -re(a)), dco), # 4.12.4 + # (exp(a*t)*erf(sqrt(a*t)), sqrt(a)/sqrt(s)/(s-a), + # S.true, Max(S.Zero, re(a)), dco), # 4.12.5 + # (erf(sqrt(a/t)/2), (1-exp(-sqrt(a*s)))/s, + # re(a) > 0, S.Zero, dco), # 4.12.6 + # (erfc(sqrt(a*t)), (sqrt(s+a)-sqrt(a))/sqrt(s+a)/s, + # S.true, -re(a), dco), # 4.12.9 + # (exp(a*t)*erfc(sqrt(a*t)), 1/(s+sqrt(a*s)), + # S.true, S.Zero, dco), # 4.12.10 + # (erfc(sqrt(a/t)/2), exp(-sqrt(a*s))/s, + # re(a) > 0, S.Zero, dco), # 4.2.11 + (besselj(n, a*t), a**n/(sqrt(s**2+a**2)*(s+sqrt(s**2+a**2))**n), + re(n) > -1, Abs(im(a)), dco), # 4.14.1 + (t**b*besselj(n, a*t), + 2**n/sqrt(pi)*gamma(n+S.Half)*a**n*(s**2+a**2)**(-n-S.Half), + And(re(n) > -S.Half, Eq(b, n)), Abs(im(a)), dco), # 4.14.7 + (t**b*besselj(n, a*t), + 2**(n+1)/sqrt(pi)*gamma(n+S(3)/2)*a**n*s*(s**2+a**2)**(-n-S(3)/2), + And(re(n) > -1, Eq(b, n+1)), Abs(im(a)), dco), # 4.14.8 + # (besselj(0, 2*sqrt(a*t)), exp(-a/s)/s, + # S.true, S.Zero, dco), # 4.14.25 + # (t**(b)*besselj(n, 2*sqrt(a*t)), a**(n/2)*s**(-n-1)*exp(-a/s), + # And(re(n) > -1, Eq(b, n*S.Half)), S.Zero, dco), # 4.14.30 + (besselj(0, a*sqrt(t**2+b*t)), + exp(b*s-b*sqrt(s**2+a**2))/sqrt(s**2+a**2), + Abs(arg(b)) < pi, Abs(im(a)), dco), # 4.15.19 + (besseli(n, a*t), a**n/(sqrt(s**2-a**2)*(s+sqrt(s**2-a**2))**n), + re(n) > -1, Abs(re(a)), dco), # 4.16.1 + (t**b*besseli(n, a*t), + 2**n/sqrt(pi)*gamma(n+S.Half)*a**n*(s**2-a**2)**(-n-S.Half), + And(re(n) > -S.Half, Eq(b, n)), Abs(re(a)), dco), # 4.16.6 + (t**b*besseli(n, a*t), + 2**(n+1)/sqrt(pi)*gamma(n+S(3)/2)*a**n*s*(s**2-a**2)**(-n-S(3)/2), + And(re(n) > -1, Eq(b, n+1)), Abs(re(a)), dco), # 4.16.7 + # (t**(b)*besseli(n, 2*sqrt(a*t)), a**(n/2)*s**(-n-1)*exp(a/s), + # And(re(n) > -1, Eq(b, n*S.Half)), S.Zero, dco), # 4.16.18 + (bessely(0, a*t), -2/pi*asinh(s/a)/sqrt(s**2+a**2), + S.true, Abs(im(a)), dco), # 4.15.44 + (besselk(0, a*t), log((s + sqrt(s**2-a**2))/a)/(sqrt(s**2-a**2)), + S.true, -re(a), dco) # 4.16.23 + ] + return laplace_transform_rules, t, s + + +@DEBUG_WRAP +def _laplace_rule_timescale(f, t, s): + """ + This function applies the time-scaling rule of the Laplace transform in + a straight-forward way. For example, if it gets ``(f(a*t), t, s)``, it will + compute ``LaplaceTransform(f(t)/a, t, s/a)`` if ``a>0``. + """ + + a = Wild('a', exclude=[t]) + g = WildFunction('g', nargs=1) + ma1 = f.match(g) + if ma1: + arg = ma1[g].args[0].collect(t) + ma2 = arg.match(a*t) + if ma2 and ma2[a].is_positive and ma2[a] != 1: + _debug(' rule: time scaling (4.1.4)') + r, pr, cr = _laplace_transform( + 1/ma2[a]*ma1[g].func(t), t, s/ma2[a], simplify=False) + return (r, pr, cr) + return None + + +@DEBUG_WRAP +def _laplace_rule_heaviside(f, t, s): + """ + This function deals with time-shifted Heaviside step functions. If the time + shift is positive, it applies the time-shift rule of the Laplace transform. + For example, if it gets ``(Heaviside(t-a)*f(t), t, s)``, it will compute + ``exp(-a*s)*LaplaceTransform(f(t+a), t, s)``. + + If the time shift is negative, the Heaviside function is simply removed + as it means nothing to the Laplace transform. + + The function does not remove a factor ``Heaviside(t)``; this is done by + the simple rules. + """ + + a = Wild('a', exclude=[t]) + y = Wild('y') + g = Wild('g') + if ma1 := f.match(Heaviside(y) * g): + if ma2 := ma1[y].match(t - a): + if ma2[a].is_positive: + _debug(' rule: time shift (4.1.4)') + r, pr, cr = _laplace_transform( + ma1[g].subs(t, t + ma2[a]), t, s, simplify=False) + return (exp(-ma2[a] * s) * r, pr, cr) + if ma2[a].is_negative: + _debug( + ' rule: Heaviside factor; negative time shift (4.1.4)') + r, pr, cr = _laplace_transform(ma1[g], t, s, simplify=False) + return (r, pr, cr) + if ma2 := ma1[y].match(a - t): + if ma2[a].is_positive: + _debug(' rule: Heaviside window open') + r, pr, cr = _laplace_transform( + (1 - Heaviside(t - ma2[a])) * ma1[g], t, s, simplify=False) + return (r, pr, cr) + if ma2[a].is_negative: + _debug(' rule: Heaviside window closed') + return (0, 0, S.true) + return None + + +@DEBUG_WRAP +def _laplace_rule_exp(f, t, s): + """ + If this function finds a factor ``exp(a*t)``, it applies the + frequency-shift rule of the Laplace transform and adjusts the convergence + plane accordingly. For example, if it gets ``(exp(-a*t)*f(t), t, s)``, it + will compute ``LaplaceTransform(f(t), t, s+a)``. + """ + + a = Wild('a', exclude=[t]) + y = Wild('y') + z = Wild('z') + ma1 = f.match(exp(y)*z) + if ma1: + ma2 = ma1[y].collect(t).match(a*t) + if ma2: + _debug(' rule: multiply with exp (4.1.5)') + r, pr, cr = _laplace_transform(ma1[z], t, s-ma2[a], + simplify=False) + return (r, pr+re(ma2[a]), cr) + return None + + +@DEBUG_WRAP +def _laplace_rule_delta(f, t, s): + """ + If this function finds a factor ``DiracDelta(b*t-a)``, it applies the + masking property of the delta distribution. For example, if it gets + ``(DiracDelta(t-a)*f(t), t, s)``, it will return + ``(f(a)*exp(-a*s), -a, True)``. + """ + # This rule is not in Bateman54 + + a = Wild('a', exclude=[t]) + b = Wild('b', exclude=[t]) + + y = Wild('y') + z = Wild('z') + ma1 = f.match(DiracDelta(y)*z) + if ma1 and not ma1[z].has(DiracDelta): + ma2 = ma1[y].collect(t).match(b*t-a) + if ma2: + _debug(' rule: multiply with DiracDelta') + loc = ma2[a]/ma2[b] + if re(loc) >= 0 and im(loc) == 0: + fn = exp(-ma2[a]/ma2[b]*s)*ma1[z] + if fn.has(sin, cos): + # Then it may be possible that a sinc() is present in the + # term; let's try this: + fn = fn.rewrite(sinc).ratsimp() + n, d = [x.subs(t, ma2[a]/ma2[b]) for x in fn.as_numer_denom()] + if d != 0: + return (n/d/ma2[b], S.NegativeInfinity, S.true) + else: + return None + else: + return (0, S.NegativeInfinity, S.true) + if ma1[y].is_polynomial(t): + ro = roots(ma1[y], t) + if ro != {} and set(ro.values()) == {1}: + slope = diff(ma1[y], t) + r = Add( + *[exp(-x*s)*ma1[z].subs(t, s)/slope.subs(t, x) + for x in list(ro.keys()) if im(x) == 0 and re(x) >= 0]) + return (r, S.NegativeInfinity, S.true) + return None + + +@DEBUG_WRAP +def _laplace_trig_split(fn): + """ + Helper function for `_laplace_rule_trig`. This function returns two terms + `f` and `g`. `f` contains all product terms with sin, cos, sinh, cosh in + them; `g` contains everything else. + """ + trigs = [S.One] + other = [S.One] + for term in Mul.make_args(fn): + if term.has(sin, cos, sinh, cosh, exp): + trigs.append(term) + else: + other.append(term) + f = Mul(*trigs) + g = Mul(*other) + return f, g + + +@DEBUG_WRAP +def _laplace_trig_expsum(f, t): + """ + Helper function for `_laplace_rule_trig`. This function expects the `f` + from `_laplace_trig_split`. It returns two lists `xm` and `xn`. `xm` is + a list of dictionaries with keys `k` and `a` representing a function + `k*exp(a*t)`. `xn` is a list of all terms that cannot be brought into + that form, which may happen, e.g., when a trigonometric function has + another function in its argument. + """ + c1 = Wild('c1', exclude=[t]) + c0 = Wild('c0', exclude=[t]) + p = Wild('p', exclude=[t]) + xm = [] + xn = [] + + x1 = f.rewrite(exp).expand() + + for term in Add.make_args(x1): + if not term.has(t): + xm.append({'k': term, 'a': 0, re: 0, im: 0}) + continue + term = _laplace_deep_collect(term.powsimp(combine='exp'), t) + + if (r := term.match(p*exp(c1*t+c0))) is not None: + xm.append({ + 'k': r[p]*exp(r[c0]), 'a': r[c1], + re: re(r[c1]), im: im(r[c1])}) + else: + xn.append(term) + return xm, xn + + +@DEBUG_WRAP +def _laplace_trig_ltex(xm, t, s): + """ + Helper function for `_laplace_rule_trig`. This function takes the list of + exponentials `xm` from `_laplace_trig_expsum` and simplifies complex + conjugate and real symmetric poles. It returns the result as a sum and + the convergence plane. + """ + results = [] + planes = [] + + def _simpc(coeffs): + nc = coeffs.copy() + for k in range(len(nc)): + ri = nc[k].as_real_imag() + if ri[0].has(im): + nc[k] = nc[k].rewrite(cos) + else: + nc[k] = (ri[0] + I*ri[1]).rewrite(cos) + return nc + + def _quadpole(t1, k1, k2, k3, s): + a, k0, a_r, a_i = t1['a'], t1['k'], t1[re], t1[im] + nc = [ + k0 + k1 + k2 + k3, + a*(k0 + k1 - k2 - k3) - 2*I*a_i*k1 + 2*I*a_i*k2, + ( + a**2*(-k0 - k1 - k2 - k3) + + a*(4*I*a_i*k0 + 4*I*a_i*k3) + + 4*a_i**2*k0 + 4*a_i**2*k3), + ( + a**3*(-k0 - k1 + k2 + k3) + + a**2*(4*I*a_i*k0 + 2*I*a_i*k1 - 2*I*a_i*k2 - 4*I*a_i*k3) + + a*(4*a_i**2*k0 - 4*a_i**2*k3)) + ] + dc = [ + S.One, S.Zero, 2*a_i**2 - 2*a_r**2, + S.Zero, a_i**4 + 2*a_i**2*a_r**2 + a_r**4] + n = Add( + *[x*s**y for x, y in zip(_simpc(nc), range(len(nc))[::-1])]) + d = Add( + *[x*s**y for x, y in zip(dc, range(len(dc))[::-1])]) + return n/d + + def _ccpole(t1, k1, s): + a, k0, a_r, a_i = t1['a'], t1['k'], t1[re], t1[im] + nc = [k0 + k1, -a*k0 - a*k1 + 2*I*a_i*k0] + dc = [S.One, -2*a_r, a_i**2 + a_r**2] + n = Add( + *[x*s**y for x, y in zip(_simpc(nc), range(len(nc))[::-1])]) + d = Add( + *[x*s**y for x, y in zip(dc, range(len(dc))[::-1])]) + return n/d + + def _rspole(t1, k2, s): + a, k0, a_r, a_i = t1['a'], t1['k'], t1[re], t1[im] + nc = [k0 + k2, a*k0 - a*k2 - 2*I*a_i*k0] + dc = [S.One, -2*I*a_i, -a_i**2 - a_r**2] + n = Add( + *[x*s**y for x, y in zip(_simpc(nc), range(len(nc))[::-1])]) + d = Add( + *[x*s**y for x, y in zip(dc, range(len(dc))[::-1])]) + return n/d + + def _sypole(t1, k3, s): + a, k0 = t1['a'], t1['k'] + nc = [k0 + k3, a*(k0 - k3)] + dc = [S.One, S.Zero, -a**2] + n = Add( + *[x*s**y for x, y in zip(_simpc(nc), range(len(nc))[::-1])]) + d = Add( + *[x*s**y for x, y in zip(dc, range(len(dc))[::-1])]) + return n/d + + def _simplepole(t1, s): + a, k0 = t1['a'], t1['k'] + n = k0 + d = s - a + return n/d + + while len(xm) > 0: + t1 = xm.pop() + i_imagsym = None + i_realsym = None + i_pointsym = None + # The following code checks all remaining poles. If t1 is a pole at + # a+b*I, then we check for a-b*I, -a+b*I, and -a-b*I, and + # assign the respective indices to i_imagsym, i_realsym, i_pointsym. + # -a-b*I / i_pointsym only applies if both a and b are != 0. + for i in range(len(xm)): + real_eq = t1[re] == xm[i][re] + realsym = t1[re] == -xm[i][re] + imag_eq = t1[im] == xm[i][im] + imagsym = t1[im] == -xm[i][im] + if realsym and imagsym and t1[re] != 0 and t1[im] != 0: + i_pointsym = i + elif realsym and imag_eq and t1[re] != 0: + i_realsym = i + elif real_eq and imagsym and t1[im] != 0: + i_imagsym = i + + # The next part looks for four possible pole constellations: + # quad: a+b*I, a-b*I, -a+b*I, -a-b*I + # cc: a+b*I, a-b*I (a may be zero) + # quad: a+b*I, -a+b*I (b may be zero) + # point: a+b*I, -a-b*I (a!=0 and b!=0 is needed, but that has been + # asserted when finding i_pointsym above.) + # If none apply, then t1 is a simple pole. + if ( + i_imagsym is not None and i_realsym is not None + and i_pointsym is not None): + results.append( + _quadpole(t1, + xm[i_imagsym]['k'], xm[i_realsym]['k'], + xm[i_pointsym]['k'], s)) + planes.append(Abs(re(t1['a']))) + # The three additional poles have now been used; to pop them + # easily we have to do it from the back. + indices_to_pop = [i_imagsym, i_realsym, i_pointsym] + indices_to_pop.sort(reverse=True) + for i in indices_to_pop: + xm.pop(i) + elif i_imagsym is not None: + results.append(_ccpole(t1, xm[i_imagsym]['k'], s)) + planes.append(t1[re]) + xm.pop(i_imagsym) + elif i_realsym is not None: + results.append(_rspole(t1, xm[i_realsym]['k'], s)) + planes.append(Abs(t1[re])) + xm.pop(i_realsym) + elif i_pointsym is not None: + results.append(_sypole(t1, xm[i_pointsym]['k'], s)) + planes.append(Abs(t1[re])) + xm.pop(i_pointsym) + else: + results.append(_simplepole(t1, s)) + planes.append(t1[re]) + + return Add(*results), Max(*planes) + + +@DEBUG_WRAP +def _laplace_rule_trig(fn, t_, s): + """ + This rule covers trigonometric factors by splitting everything into a + sum of exponential functions and collecting complex conjugate poles and + real symmetric poles. + """ + t = Dummy('t', real=True) + + if not fn.has(sin, cos, sinh, cosh): + return None + + f, g = _laplace_trig_split(fn.subs(t_, t)) + xm, xn = _laplace_trig_expsum(f, t) + + if len(xn) > 0: + # TODO not implemented yet, but also not important + return None + + if not g.has(t): + r, p = _laplace_trig_ltex(xm, t, s) + return g*r, p, S.true + else: + # Just transform `g` and make frequency-shifted copies + planes = [] + results = [] + G, G_plane, G_cond = _laplace_transform(g, t, s, simplify=False) + for x1 in xm: + results.append(x1['k']*G.subs(s, s-x1['a'])) + planes.append(G_plane+re(x1['a'])) + return Add(*results).subs(t, t_), Max(*planes), G_cond + + +@DEBUG_WRAP +def _laplace_rule_diff(f, t, s): + """ + This function looks for derivatives in the time domain and replaces it + by factors of `s` and initial conditions in the frequency domain. For + example, if it gets ``(diff(f(t), t), t, s)``, it will compute + ``s*LaplaceTransform(f(t), t, s) - f(0)``. + """ + + a = Wild('a', exclude=[t]) + n = Wild('n', exclude=[t]) + g = WildFunction('g') + ma1 = f.match(a*Derivative(g, (t, n))) + if ma1 and ma1[n].is_integer: + m = [z.has(t) for z in ma1[g].args] + if sum(m) == 1: + _debug(' rule: time derivative (4.1.8)') + d = [] + for k in range(ma1[n]): + if k == 0: + y = ma1[g].subs(t, 0) + else: + y = Derivative(ma1[g], (t, k)).subs(t, 0) + d.append(s**(ma1[n]-k-1)*y) + r, pr, cr = _laplace_transform(ma1[g], t, s, simplify=False) + return (ma1[a]*(s**ma1[n]*r - Add(*d)), pr, cr) + return None + + +@DEBUG_WRAP +def _laplace_rule_sdiff(f, t, s): + """ + This function looks for multiplications with polynoimials in `t` as they + correspond to differentiation in the frequency domain. For example, if it + gets ``(t*f(t), t, s)``, it will compute + ``-Derivative(LaplaceTransform(f(t), t, s), s)``. + """ + + if f.is_Mul: + pfac = [1] + ofac = [1] + for fac in Mul.make_args(f): + if fac.is_polynomial(t): + pfac.append(fac) + else: + ofac.append(fac) + if len(pfac) > 1: + pex = prod(pfac) + pc = Poly(pex, t).all_coeffs() + N = len(pc) + if N > 1: + oex = prod(ofac) + r_, p_, c_ = _laplace_transform(oex, t, s, simplify=False) + deri = [r_] + d1 = False + try: + d1 = -diff(deri[-1], s) + except ValueError: + d1 = False + if r_.has(LaplaceTransform): + for k in range(N-1): + deri.append((-1)**(k+1)*Derivative(r_, s, k+1)) + elif d1: + deri.append(d1) + for k in range(N-2): + deri.append(-diff(deri[-1], s)) + if d1: + r = Add(*[pc[N-n-1]*deri[n] for n in range(N)]) + return (r, p_, c_) + # We still have to cover the possibility that there is a symbolic positive + # integer exponent. + n = Wild('n', exclude=[t]) + g = Wild('g') + if ma1 := f.match(t**n*g): + if ma1[n].is_integer and ma1[n].is_positive: + r_, p_, c_ = _laplace_transform(ma1[g], t, s, simplify=False) + return (-1)**ma1[n]*diff(r_, (s, ma1[n])), p_, c_ + return None + + +@DEBUG_WRAP +def _laplace_expand(f, t, s): + """ + This function tries to expand its argument with successively stronger + methods: first it will expand on the top level, then it will expand any + multiplications in depth, then it will try all available expansion methods, + and finally it will try to expand trigonometric functions. + + If it can expand, it will then compute the Laplace transform of the + expanded term. + """ + + r = expand(f, deep=False) + if r.is_Add: + return _laplace_transform(r, t, s, simplify=False) + r = expand_mul(f) + if r.is_Add: + return _laplace_transform(r, t, s, simplify=False) + r = expand(f) + if r.is_Add: + return _laplace_transform(r, t, s, simplify=False) + if r != f: + return _laplace_transform(r, t, s, simplify=False) + r = expand(expand_trig(f)) + if r.is_Add: + return _laplace_transform(r, t, s, simplify=False) + return None + + +@DEBUG_WRAP +def _laplace_apply_prog_rules(f, t, s): + """ + This function applies all program rules and returns the result if one + of them gives a result. + """ + + prog_rules = [_laplace_rule_heaviside, _laplace_rule_delta, + _laplace_rule_timescale, _laplace_rule_exp, + _laplace_rule_trig, + _laplace_rule_diff, _laplace_rule_sdiff] + + for p_rule in prog_rules: + if (L := p_rule(f, t, s)) is not None: + return L + return None + + +@DEBUG_WRAP +def _laplace_apply_simple_rules(f, t, s): + """ + This function applies all simple rules and returns the result if one + of them gives a result. + """ + simple_rules, t_, s_ = _laplace_build_rules() + prep_old = '' + prep_f = '' + for t_dom, s_dom, check, plane, prep in simple_rules: + if prep_old != prep: + prep_f = prep(f.subs({t: t_})) + prep_old = prep + ma = prep_f.match(t_dom) + if ma: + try: + c = check.xreplace(ma) + except TypeError: + # This may happen if the time function has imaginary + # numbers in it. Then we give up. + continue + if c == S.true: + return (s_dom.xreplace(ma).subs({s_: s}), + plane.xreplace(ma), S.true) + return None + + +@DEBUG_WRAP +def _piecewise_to_heaviside(f, t): + """ + This function converts a Piecewise expression to an expression written + with Heaviside. It is not exact, but valid in the context of the Laplace + transform. + """ + if not t.is_real: + r = Dummy('r', real=True) + return _piecewise_to_heaviside(f.xreplace({t: r}), r).xreplace({r: t}) + x = piecewise_exclusive(f) + r = [] + for fn, cond in x.args: + # Here we do not need to do many checks because piecewise_exclusive + # has a clearly predictable output. However, if any of the conditions + # is not relative to t, this function just returns the input argument. + if isinstance(cond, Relational) and t in cond.args: + if isinstance(cond, (Eq, Ne)): + # We do not cover this case; these would be single-point + # exceptions that do not play a role in Laplace practice, + # except if they contain Dirac impulses, and then we can + # expect users to not try to use Piecewise for writing it. + return f + else: + r.append(Heaviside(cond.gts - cond.lts)*fn) + elif isinstance(cond, Or) and len(cond.args) == 2: + # Or(t<2, t>4), Or(t>4, t<=2), ... in any order with any <= >= + for c2 in cond.args: + if c2.lhs == t: + r.append(Heaviside(c2.gts - c2.lts)*fn) + else: + return f + elif isinstance(cond, And) and len(cond.args) == 2: + # And(t>2, t<4), And(t>4, t<=2), ... in any order with any <= >= + c0, c1 = cond.args + if c0.lhs == t and c1.lhs == t: + if '>' in c0.rel_op: + c0, c1 = c1, c0 + r.append( + (Heaviside(c1.gts - c1.lts) - + Heaviside(c0.lts - c0.gts))*fn) + else: + return f + else: + return f + return Add(*r) + + +def laplace_correspondence(f, fdict, /): + """ + This helper function takes a function `f` that is the result of a + ``laplace_transform`` or an ``inverse_laplace_transform``. It replaces all + unevaluated ``LaplaceTransform(y(t), t, s)`` by `Y(s)` for any `s` and + all ``InverseLaplaceTransform(Y(s), s, t)`` by `y(t)` for any `t` if + ``fdict`` contains a correspondence ``{y: Y}``. + + Parameters + ========== + + f : sympy expression + Expression containing unevaluated ``LaplaceTransform`` or + ``LaplaceTransform`` objects. + fdict : dictionary + Dictionary containing one or more function correspondences, + e.g., ``{x: X, y: Y}`` meaning that ``X`` and ``Y`` are the + Laplace transforms of ``x`` and ``y``, respectively. + + Examples + ======== + + >>> from sympy import laplace_transform, diff, Function + >>> from sympy import laplace_correspondence, inverse_laplace_transform + >>> from sympy.abc import t, s + >>> y = Function("y") + >>> Y = Function("Y") + >>> z = Function("z") + >>> Z = Function("Z") + >>> f = laplace_transform(diff(y(t), t, 1) + z(t), t, s, noconds=True) + >>> laplace_correspondence(f, {y: Y, z: Z}) + s*Y(s) + Z(s) - y(0) + >>> f = inverse_laplace_transform(Y(s), s, t) + >>> laplace_correspondence(f, {y: Y}) + y(t) + """ + p = Wild('p') + s = Wild('s') + t = Wild('t') + a = Wild('a') + if ( + not isinstance(f, Expr) + or (not f.has(LaplaceTransform) + and not f.has(InverseLaplaceTransform))): + return f + for y, Y in fdict.items(): + if ( + (m := f.match(LaplaceTransform(y(a), t, s))) is not None + and m[a] == m[t]): + return Y(m[s]) + if ( + (m := f.match(InverseLaplaceTransform(Y(a), s, t, p))) + is not None + and m[a] == m[s]): + return y(m[t]) + func = f.func + args = [laplace_correspondence(arg, fdict) for arg in f.args] + return func(*args) + + +def laplace_initial_conds(f, t, fdict, /): + """ + This helper function takes a function `f` that is the result of a + ``laplace_transform``. It takes an fdict of the form ``{y: [1, 4, 2]}``, + where the values in the list are the initial value, the initial slope, the + initial second derivative, etc., of the function `y(t)`, and replaces all + unevaluated initial conditions. + + Parameters + ========== + + f : sympy expression + Expression containing initial conditions of unevaluated functions. + t : sympy expression + Variable for which the initial conditions are to be applied. + fdict : dictionary + Dictionary containing a list of initial conditions for every + function, e.g., ``{y: [0, 1, 2], x: [3, 4, 5]}``. The order + of derivatives is ascending, so `0`, `1`, `2` are `y(0)`, `y'(0)`, + and `y''(0)`, respectively. + + Examples + ======== + + >>> from sympy import laplace_transform, diff, Function + >>> from sympy import laplace_correspondence, laplace_initial_conds + >>> from sympy.abc import t, s + >>> y = Function("y") + >>> Y = Function("Y") + >>> f = laplace_transform(diff(y(t), t, 3), t, s, noconds=True) + >>> g = laplace_correspondence(f, {y: Y}) + >>> laplace_initial_conds(g, t, {y: [2, 4, 8, 16, 32]}) + s**3*Y(s) - 2*s**2 - 4*s - 8 + """ + for y, ic in fdict.items(): + for k in range(len(ic)): + if k == 0: + f = f.replace(y(0), ic[0]) + elif k == 1: + f = f.replace(Subs(Derivative(y(t), t), t, 0), ic[1]) + else: + f = f.replace(Subs(Derivative(y(t), (t, k)), t, 0), ic[k]) + return f + + +@DEBUG_WRAP +def _laplace_transform(fn, t_, s_, *, simplify): + """ + Front-end function of the Laplace transform. It tries to apply all known + rules recursively, and if everything else fails, it tries to integrate. + """ + + terms_t = Add.make_args(fn) + terms_s = [] + terms = [] + planes = [] + conditions = [] + + for ff in terms_t: + k, ft = ff.as_independent(t_, as_Add=False) + if ft.has(SingularityFunction): + _terms = Add.make_args(ft.rewrite(Heaviside)) + for _term in _terms: + k1, f1 = _term.as_independent(t_, as_Add=False) + terms.append((k*k1, f1)) + elif ft.func == Piecewise and not ft.has(DiracDelta(t_)): + _terms = Add.make_args(_piecewise_to_heaviside(ft, t_)) + for _term in _terms: + k1, f1 = _term.as_independent(t_, as_Add=False) + terms.append((k*k1, f1)) + else: + terms.append((k, ft)) + + for k, ft in terms: + if ft.has(SingularityFunction): + r = (LaplaceTransform(ft, t_, s_), S.NegativeInfinity, True) + else: + if ft.has(Heaviside(t_)) and not ft.has(DiracDelta(t_)): + # For t>=0, Heaviside(t)=1 can be used, except if there is also + # a DiracDelta(t) present, in which case removing Heaviside(t) + # is unnecessary because _laplace_rule_delta can deal with it. + ft = ft.subs(Heaviside(t_), 1) + if ( + (r := _laplace_apply_simple_rules(ft, t_, s_)) + is not None or + (r := _laplace_apply_prog_rules(ft, t_, s_)) + is not None or + (r := _laplace_expand(ft, t_, s_)) is not None): + pass + elif any(undef.has(t_) for undef in ft.atoms(AppliedUndef)): + # If there are undefined functions f(t) then integration is + # unlikely to do anything useful so we skip it and given an + # unevaluated LaplaceTransform. + r = (LaplaceTransform(ft, t_, s_), S.NegativeInfinity, True) + elif (r := _laplace_transform_integration( + ft, t_, s_, simplify=simplify)) is not None: + pass + else: + r = (LaplaceTransform(ft, t_, s_), S.NegativeInfinity, True) + (ri_, pi_, ci_) = r + terms_s.append(k*ri_) + planes.append(pi_) + conditions.append(ci_) + + result = Add(*terms_s) + if simplify: + result = result.simplify(doit=False) + plane = Max(*planes) + condition = And(*conditions) + + return result, plane, condition + + +class LaplaceTransform(IntegralTransform): + """ + Class representing unevaluated Laplace transforms. + + For usage of this class, see the :class:`IntegralTransform` docstring. + + For how to compute Laplace transforms, see the :func:`laplace_transform` + docstring. + + If this is called with ``.doit()``, it returns the Laplace transform as an + expression. If it is called with ``.doit(noconds=False)``, it returns a + tuple containing the same expression, a convergence plane, and conditions. + """ + + _name = 'Laplace' + + def _compute_transform(self, f, t, s, **hints): + _simplify = hints.get('simplify', False) + LT = _laplace_transform_integration(f, t, s, simplify=_simplify) + return LT + + def _as_integral(self, f, t, s): + return Integral(f*exp(-s*t), (t, S.Zero, S.Infinity)) + + def doit(self, **hints): + """ + Try to evaluate the transform in closed form. + + Explanation + =========== + + Standard hints are the following: + - ``noconds``: if True, do not return convergence conditions. The + default setting is `True`. + - ``simplify``: if True, it simplifies the final result. The + default setting is `False`. + """ + _noconds = hints.get('noconds', True) + _simplify = hints.get('simplify', False) + + debugf('[LT doit] (%s, %s, %s)', (self.function, + self.function_variable, + self.transform_variable)) + + t_ = self.function_variable + s_ = self.transform_variable + fn = self.function + + r = _laplace_transform(fn, t_, s_, simplify=_simplify) + + if _noconds: + return r[0] + else: + return r + + +def laplace_transform(f, t, s, legacy_matrix=True, **hints): + r""" + Compute the Laplace Transform `F(s)` of `f(t)`, + + .. math :: F(s) = \int_{0^{-}}^\infty e^{-st} f(t) \mathrm{d}t. + + Explanation + =========== + + For all sensible functions, this converges absolutely in a + half-plane + + .. math :: a < \operatorname{Re}(s) + + This function returns ``(F, a, cond)`` where ``F`` is the Laplace + transform of ``f``, `a` is the half-plane of convergence, and `cond` are + auxiliary convergence conditions. + + The implementation is rule-based, and if you are interested in which + rules are applied, and whether integration is attempted, you can switch + debug information on by setting ``sympy.SYMPY_DEBUG=True``. The numbers + of the rules in the debug information (and the code) refer to Bateman's + Tables of Integral Transforms [1]. + + The lower bound is `0-`, meaning that this bound should be approached + from the lower side. This is only necessary if distributions are involved. + At present, it is only done if `f(t)` contains ``DiracDelta``, in which + case the Laplace transform is computed implicitly as + + .. math :: + F(s) = \lim_{\tau\to 0^{-}} \int_{\tau}^\infty e^{-st} + f(t) \mathrm{d}t + + by applying rules. + + If the Laplace transform cannot be fully computed in closed form, this + function returns expressions containing unevaluated + :class:`LaplaceTransform` objects. + + For a description of possible hints, refer to the docstring of + :func:`sympy.integrals.transforms.IntegralTransform.doit`. If + ``noconds=True``, only `F` will be returned (i.e. not ``cond``, and also + not the plane ``a``). + + .. deprecated:: 1.9 + Legacy behavior for matrices where ``laplace_transform`` with + ``noconds=False`` (the default) returns a Matrix whose elements are + tuples. The behavior of ``laplace_transform`` for matrices will change + in a future release of SymPy to return a tuple of the transformed + Matrix and the convergence conditions for the matrix as a whole. Use + ``legacy_matrix=False`` to enable the new behavior. + + Examples + ======== + + >>> from sympy import DiracDelta, exp, laplace_transform + >>> from sympy.abc import t, s, a + >>> laplace_transform(t**4, t, s) + (24/s**5, 0, True) + >>> laplace_transform(t**a, t, s) + (gamma(a + 1)/(s*s**a), 0, re(a) > -1) + >>> laplace_transform(DiracDelta(t)-a*exp(-a*t), t, s, simplify=True) + (s/(a + s), -re(a), True) + + There are also helper functions that make it easy to solve differential + equations by Laplace transform. For example, to solve + + .. math :: m x''(t) + d x'(t) + k x(t) = 0 + + with initial value `0` and initial derivative `v`: + + >>> from sympy import Function, laplace_correspondence, diff, solve + >>> from sympy import laplace_initial_conds, inverse_laplace_transform + >>> from sympy.abc import d, k, m, v + >>> x = Function('x') + >>> X = Function('X') + >>> f = m*diff(x(t), t, 2) + d*diff(x(t), t) + k*x(t) + >>> F = laplace_transform(f, t, s, noconds=True) + >>> F = laplace_correspondence(F, {x: X}) + >>> F = laplace_initial_conds(F, t, {x: [0, v]}) + >>> F + d*s*X(s) + k*X(s) + m*(s**2*X(s) - v) + >>> Xs = solve(F, X(s))[0] + >>> Xs + m*v/(d*s + k + m*s**2) + >>> inverse_laplace_transform(Xs, s, t) + 2*v*exp(-d*t/(2*m))*sin(t*sqrt((-d**2 + 4*k*m)/m**2)/2)*Heaviside(t)/sqrt((-d**2 + 4*k*m)/m**2) + + References + ========== + + .. [1] Erdelyi, A. (ed.), Tables of Integral Transforms, Volume 1, + Bateman Manuscript Prooject, McGraw-Hill (1954), available: + https://resolver.caltech.edu/CaltechAUTHORS:20140123-101456353 + + See Also + ======== + + inverse_laplace_transform, mellin_transform, fourier_transform + hankel_transform, inverse_hankel_transform + + """ + + _noconds = hints.get('noconds', False) + _simplify = hints.get('simplify', False) + + if isinstance(f, MatrixBase) and hasattr(f, 'applyfunc'): + + conds = not hints.get('noconds', False) + + if conds and legacy_matrix: + adt = 'deprecated-laplace-transform-matrix' + sympy_deprecation_warning( + """ +Calling laplace_transform() on a Matrix with noconds=False (the default) is +deprecated. Either noconds=True or use legacy_matrix=False to get the new +behavior. + """, + deprecated_since_version='1.9', + active_deprecations_target=adt, + ) + # Temporarily disable the deprecation warning for non-Expr objects + # in Matrix + with ignore_warnings(SymPyDeprecationWarning): + return f.applyfunc( + lambda fij: laplace_transform(fij, t, s, **hints)) + else: + elements_trans = [laplace_transform( + fij, t, s, **hints) for fij in f] + if conds: + elements, avals, conditions = zip(*elements_trans) + f_laplace = type(f)(*f.shape, elements) + return f_laplace, Max(*avals), And(*conditions) + else: + return type(f)(*f.shape, elements_trans) + + LT, p, c = LaplaceTransform(f, t, s).doit(noconds=False, + simplify=_simplify) + + if not _noconds: + return LT, p, c + else: + return LT + + +@DEBUG_WRAP +def _inverse_laplace_transform_integration(F, s, t_, plane, *, simplify): + """ The backend function for inverse Laplace transforms. """ + from sympy.integrals.meijerint import meijerint_inversion, _get_coeff_exp + from sympy.integrals.transforms import inverse_mellin_transform + + # There are two strategies we can try: + # 1) Use inverse mellin transform, related by a simple change of variables. + # 2) Use the inversion integral. + + t = Dummy('t', real=True) + + def pw_simp(*args): + """ Simplify a piecewise expression from hyperexpand. """ + if len(args) != 3: + return Piecewise(*args) + arg = args[2].args[0].argument + coeff, exponent = _get_coeff_exp(arg, t) + e1 = args[0].args[0] + e2 = args[1].args[0] + return ( + Heaviside(1/Abs(coeff) - t**exponent)*e1 + + Heaviside(t**exponent - 1/Abs(coeff))*e2) + + if F.is_rational_function(s): + F = F.apart(s) + + if F.is_Add: + f = Add( + *[_inverse_laplace_transform_integration(X, s, t, plane, simplify) + for X in F.args]) + return _simplify(f.subs(t, t_), simplify), True + + try: + f, cond = inverse_mellin_transform(F, s, exp(-t), (None, S.Infinity), + needeval=True, noconds=False) + except IntegralTransformError: + f = None + + if f is None: + f = meijerint_inversion(F, s, t) + if f is None: + return None + if f.is_Piecewise: + f, cond = f.args[0] + if f.has(Integral): + return None + else: + cond = S.true + f = f.replace(Piecewise, pw_simp) + + if f.is_Piecewise: + # many of the functions called below can't work with piecewise + # (b/c it has a bool in args) + return f.subs(t, t_), cond + + u = Dummy('u') + + def simp_heaviside(arg, H0=S.Half): + a = arg.subs(exp(-t), u) + if a.has(t): + return Heaviside(arg, H0) + from sympy.solvers.inequalities import _solve_inequality + rel = _solve_inequality(a > 0, u) + if rel.lts == u: + k = log(rel.gts) + return Heaviside(t + k, H0) + else: + k = log(rel.lts) + return Heaviside(-(t + k), H0) + + f = f.replace(Heaviside, simp_heaviside) + + def simp_exp(arg): + return expand_complex(exp(arg)) + + f = f.replace(exp, simp_exp) + + return _simplify(f.subs(t, t_), simplify), cond + + +@DEBUG_WRAP +def _complete_the_square_in_denom(f, s): + from sympy.simplify.radsimp import fraction + [n, d] = fraction(f) + if d.is_polynomial(s): + cf = d.as_poly(s).all_coeffs() + if len(cf) == 3: + a, b, c = cf + d = a*((s+b/(2*a))**2+c/a-(b/(2*a))**2) + return n/d + + +@cacheit +def _inverse_laplace_build_rules(): + """ + This is an internal helper function that returns the table of inverse + Laplace transform rules in terms of the time variable `t` and the + frequency variable `s`. It is used by `_inverse_laplace_apply_rules`. + """ + s = Dummy('s') + t = Dummy('t') + a = Wild('a', exclude=[s]) + b = Wild('b', exclude=[s]) + c = Wild('c', exclude=[s]) + + _debug('_inverse_laplace_build_rules is building rules') + + def _frac(f, s): + try: + return f.factor(s) + except PolynomialError: + return f + + def same(f): return f + # This list is sorted according to the prep function needed. + _ILT_rules = [ + (a/s, a, S.true, same, 1), + ( + b*(s+a)**(-c), t**(c-1)*exp(-a*t)/gamma(c), + S.true, same, 1), + (1/(s**2+a**2)**2, (sin(a*t) - a*t*cos(a*t))/(2*a**3), + S.true, same, 1), + # The next two rules must be there in that order. For the second + # one, the condition would be a != 0 or, respectively, to take the + # limit a -> 0 after the transform if a == 0. It is much simpler if + # the case a == 0 has its own rule. + (1/(s**b), t**(b - 1)/gamma(b), S.true, same, 1), + (1/(s*(s+a)**b), lowergamma(b, a*t)/(a**b*gamma(b)), + S.true, same, 1) + ] + return _ILT_rules, s, t + + +@DEBUG_WRAP +def _inverse_laplace_apply_simple_rules(f, s, t): + """ + Helper function for the class InverseLaplaceTransform. + """ + if f == 1: + _debug(' rule: 1 o---o DiracDelta()') + return DiracDelta(t), S.true + + _ILT_rules, s_, t_ = _inverse_laplace_build_rules() + _prep = '' + fsubs = f.subs({s: s_}) + + for s_dom, t_dom, check, prep, fac in _ILT_rules: + if _prep != (prep, fac): + _F = prep(fsubs*fac) + _prep = (prep, fac) + ma = _F.match(s_dom) + if ma: + c = check + if c is not S.true: + args = [x.xreplace(ma) for x in c[0]] + c = c[1](*args) + if c == S.true: + return Heaviside(t)*t_dom.xreplace(ma).subs({t_: t}), S.true + + return None + + +@DEBUG_WRAP +def _inverse_laplace_diff(f, s, t, plane): + """ + Helper function for the class InverseLaplaceTransform. + """ + a = Wild('a', exclude=[s]) + n = Wild('n', exclude=[s]) + g = Wild('g') + ma = f.match(a*Derivative(g, (s, n))) + if ma and ma[n].is_integer: + _debug(' rule: t**n*f(t) o---o (-1)**n*diff(F(s), s, n)') + r, c = _inverse_laplace_transform( + ma[g], s, t, plane, simplify=False, dorational=False) + return (-t)**ma[n]*r, c + return None + + +@DEBUG_WRAP +def _inverse_laplace_time_shift(F, s, t, plane): + """ + Helper function for the class InverseLaplaceTransform. + """ + a = Wild('a', exclude=[s]) + g = Wild('g') + + if not F.has(s): + return F*DiracDelta(t), S.true + if not F.has(exp): + return None + + ma1 = F.match(exp(a*s)) + if ma1: + if ma1[a].is_negative: + _debug(' rule: exp(-a*s) o---o DiracDelta(t-a)') + return DiracDelta(t+ma1[a]), S.true + else: + return InverseLaplaceTransform(F, s, t, plane), S.true + + ma1 = F.match(exp(a*s)*g) + if ma1: + if ma1[a].is_negative: + _debug(' rule: exp(-a*s)*F(s) o---o Heaviside(t-a)*f(t-a)') + return _inverse_laplace_transform( + ma1[g], s, t+ma1[a], plane, simplify=False, dorational=True) + else: + return InverseLaplaceTransform(F, s, t, plane), S.true + return None + + +@DEBUG_WRAP +def _inverse_laplace_freq_shift(F, s, t, plane): + """ + Helper function for the class InverseLaplaceTransform. + """ + if not F.has(s): + return F*DiracDelta(t), S.true + if len(args := F.args) == 1: + a = Wild('a', exclude=[s]) + if (ma := args[0].match(s-a)) and re(ma[a]).is_positive: + _debug(' rule: F(s-a) o---o exp(-a*t)*f(t)') + return ( + exp(-ma[a]*t) * + InverseLaplaceTransform(F.func(s), s, t, plane), S.true) + return None + + +@DEBUG_WRAP +def _inverse_laplace_time_diff(F, s, t, plane): + """ + Helper function for the class InverseLaplaceTransform. + """ + n = Wild('n', exclude=[s]) + g = Wild('g') + + ma1 = F.match(s**n*g) + if ma1 and ma1[n].is_integer and ma1[n].is_positive: + _debug(' rule: s**n*F(s) o---o diff(f(t), t, n)') + r, c = _inverse_laplace_transform( + ma1[g], s, t, plane, simplify=False, dorational=True) + r = r.replace(Heaviside(t), 1) + if r.has(InverseLaplaceTransform): + return diff(r, t, ma1[n]), c + else: + return Heaviside(t)*diff(r, t, ma1[n]), c + return None + + +@DEBUG_WRAP +def _inverse_laplace_irrational(fn, s, t, plane): + """ + Helper function for the class InverseLaplaceTransform. + """ + + a = Wild('a', exclude=[s]) + b = Wild('b', exclude=[s]) + m = Wild('m', exclude=[s]) + n = Wild('n', exclude=[s]) + + result = None + condition = S.true + + fa = fn.as_ordered_factors() + + ma = [x.match((a*s**m+b)**n) for x in fa] + + if None in ma: + return None + + constants = S.One + zeros = [] + poles = [] + rest = [] + + for term in ma: + if term[a] == 0: + constants = constants*term + elif term[n].is_positive: + zeros.append(term) + elif term[n].is_negative: + poles.append(term) + else: + rest.append(term) + + # The code below assumes that the poles are sorted in a specific way: + poles = sorted(poles, key=lambda x: (x[n], x[b] != 0, x[b])) + zeros = sorted(zeros, key=lambda x: (x[n], x[b] != 0, x[b])) + + if len(rest) != 0: + return None + + if len(poles) == 1 and len(zeros) == 0: + if poles[0][n] == -1 and poles[0][m] == S.Half: + # 1/(a0*sqrt(s)+b0) == 1/a0 * 1/(sqrt(s)+b0/a0) + a_ = poles[0][b]/poles[0][a] + k_ = 1/poles[0][a]*constants + if a_.is_positive: + result = ( + k_/sqrt(pi)/sqrt(t) - + k_*a_*exp(a_**2*t)*erfc(a_*sqrt(t))) + _debug(' rule 5.3.4') + elif poles[0][n] == -2 and poles[0][m] == S.Half: + # 1/(a0*sqrt(s)+b0)**2 == 1/a0**2 * 1/(sqrt(s)+b0/a0)**2 + a_sq = poles[0][b]/poles[0][a] + a_ = a_sq**2 + k_ = 1/poles[0][a]**2*constants + if a_sq.is_positive: + result = ( + k_*(1 - 2/sqrt(pi)*sqrt(a_)*sqrt(t) + + (1-2*a_*t)*exp(a_*t)*(erf(sqrt(a_)*sqrt(t))-1))) + _debug(' rule 5.3.10') + elif poles[0][n] == -3 and poles[0][m] == S.Half: + # 1/(a0*sqrt(s)+b0)**3 == 1/a0**3 * 1/(sqrt(s)+b0/a0)**3 + a_ = poles[0][b]/poles[0][a] + k_ = 1/poles[0][a]**3*constants + if a_.is_positive: + result = ( + k_*(2/sqrt(pi)*(a_**2*t+1)*sqrt(t) - + a_*t*exp(a_**2*t)*(2*a_**2*t+3)*erfc(a_*sqrt(t)))) + _debug(' rule 5.3.13') + elif poles[0][n] == -4 and poles[0][m] == S.Half: + # 1/(a0*sqrt(s)+b0)**4 == 1/a0**4 * 1/(sqrt(s)+b0/a0)**4 + a_ = poles[0][b]/poles[0][a] + k_ = 1/poles[0][a]**4*constants/3 + if a_.is_positive: + result = ( + k_*(t*(4*a_**4*t**2+12*a_**2*t+3)*exp(a_**2*t) * + erfc(a_*sqrt(t)) - + 2/sqrt(pi)*a_**3*t**(S(5)/2)*(2*a_**2*t+5))) + _debug(' rule 5.3.16') + elif poles[0][n] == -S.Half and poles[0][m] == 2: + # 1/sqrt(a0*s**2+b0) == 1/sqrt(a0) * 1/sqrt(s**2+b0/a0) + a_ = sqrt(poles[0][b]/poles[0][a]) + k_ = 1/sqrt(poles[0][a])*constants + result = (k_*(besselj(0, a_*t))) + _debug(' rule 5.3.35/44') + + elif len(poles) == 1 and len(zeros) == 1: + if ( + poles[0][n] == -3 and poles[0][m] == S.Half and + zeros[0][n] == S.Half and zeros[0][b] == 0): + # sqrt(az*s)/(ap*sqrt(s+bp)**3) + # == sqrt(az)/ap * sqrt(s)/(sqrt(s+bp)**3) + a_ = poles[0][b] + k_ = sqrt(zeros[0][a])/poles[0][a]*constants + result = ( + k_*(2*a_**4*t**2+5*a_**2*t+1)*exp(a_**2*t) * + erfc(a_*sqrt(t)) - 2/sqrt(pi)*a_*(a_**2*t+2)*sqrt(t)) + _debug(' rule 5.3.14') + if ( + poles[0][n] == -1 and poles[0][m] == 1 and + zeros[0][n] == S.Half and zeros[0][m] == 1): + # sqrt(az*s+bz)/(ap*s+bp) + # == sqrt(az)/ap * (sqrt(s+bz/az)/(s+bp/ap)) + a_ = zeros[0][b]/zeros[0][a] + b_ = poles[0][b]/poles[0][a] + k_ = sqrt(zeros[0][a])/poles[0][a]*constants + result = ( + k_*(exp(-a_*t)/sqrt(t)/sqrt(pi)+sqrt(a_-b_) * + exp(-b_*t)*erf(sqrt(a_-b_)*sqrt(t)))) + _debug(' rule 5.3.22') + + elif len(poles) == 2 and len(zeros) == 0: + if ( + poles[0][n] == -1 and poles[0][m] == 1 and + poles[1][n] == -S.Half and poles[1][m] == 1 and + poles[1][b] == 0): + # 1/((a0*s+b0)*sqrt(a1*s)) + # == 1/(a0*sqrt(a1)) * 1/((s+b0/a0)*sqrt(s)) + a_ = -poles[0][b]/poles[0][a] + k_ = 1/sqrt(poles[1][a])/poles[0][a]*constants + if a_.is_positive: + result = (k_/sqrt(a_)*exp(a_*t)*erf(sqrt(a_)*sqrt(t))) + _debug(' rule 5.3.1') + elif ( + poles[0][n] == -1 and poles[0][m] == 1 and poles[0][b] == 0 and + poles[1][n] == -1 and poles[1][m] == S.Half): + # 1/(a0*s*(a1*sqrt(s)+b1)) + # == 1/(a0*a1) * 1/(s*(sqrt(s)+b1/a1)) + a_ = poles[1][b]/poles[1][a] + k_ = 1/poles[0][a]/poles[1][a]/a_*constants + if a_.is_positive: + result = k_*(1-exp(a_**2*t)*erfc(a_*sqrt(t))) + _debug(' rule 5.3.5') + elif ( + poles[0][n] == -1 and poles[0][m] == S.Half and + poles[1][n] == -S.Half and poles[1][m] == 1 and + poles[1][b] == 0): + # 1/((a0*sqrt(s)+b0)*(sqrt(a1*s)) + # == 1/(a0*sqrt(a1)) * 1/((sqrt(s)+b0/a0)"sqrt(s)) + a_ = poles[0][b]/poles[0][a] + k_ = 1/(poles[0][a]*sqrt(poles[1][a]))*constants + if a_.is_positive: + result = k_*exp(a_**2*t)*erfc(a_*sqrt(t)) + _debug(' rule 5.3.7') + elif ( + poles[0][n] == -S(3)/2 and poles[0][m] == 1 and + poles[0][b] == 0 and poles[1][n] == -1 and + poles[1][m] == S.Half): + # 1/((a0**(3/2)*s**(3/2))*(a1*sqrt(s)+b1)) + # == 1/(a0**(3/2)*a1) 1/((s**(3/2))*(sqrt(s)+b1/a1)) + # Note that Bateman54 5.3 (8) is incorrect; there (sqrt(p)+a) + # should be (sqrt(p)+a)**(-1). + a_ = poles[1][b]/poles[1][a] + k_ = 1/(poles[0][a]**(S(3)/2)*poles[1][a])/a_**2*constants + if a_.is_positive: + result = ( + k_*(2/sqrt(pi)*a_*sqrt(t)+exp(a_**2*t)*erfc(a_*sqrt(t))-1)) + _debug(' rule 5.3.8') + elif ( + poles[0][n] == -2 and poles[0][m] == S.Half and + poles[1][n] == -1 and poles[1][m] == 1 and + poles[1][b] == 0): + # 1/((a0*sqrt(s)+b0)**2*a1*s) + # == 1/a0**2/a1 * 1/(sqrt(s)+b0/a0)**2/s + a_sq = poles[0][b]/poles[0][a] + a_ = a_sq**2 + k_ = 1/poles[0][a]**2/poles[1][a]*constants + if a_sq.is_positive: + result = ( + k_*(1/a_ + (2*t-1/a_)*exp(a_*t)*erfc(sqrt(a_)*sqrt(t)) - + 2/sqrt(pi)/sqrt(a_)*sqrt(t))) + _debug(' rule 5.3.11') + elif ( + poles[0][n] == -2 and poles[0][m] == S.Half and + poles[1][n] == -S.Half and poles[1][m] == 1 and + poles[1][b] == 0): + # 1/((a0*sqrt(s)+b0)**2*sqrt(a1*s)) + # == 1/a0**2/sqrt(a1) * 1/(sqrt(s)+b0/a0)**2/sqrt(s) + a_ = poles[0][b]/poles[0][a] + k_ = 1/poles[0][a]**2/sqrt(poles[1][a])*constants + if a_.is_positive: + result = ( + k_*(2/sqrt(pi)*sqrt(t) - + 2*a_*t*exp(a_**2*t)*erfc(a_*sqrt(t)))) + _debug(' rule 5.3.12') + elif ( + poles[0][n] == -3 and poles[0][m] == S.Half and + poles[1][n] == -S.Half and poles[1][m] == 1 and + poles[1][b] == 0): + # 1 / (sqrt(a1*s)*(a0*sqrt(s+b0)**3)) + # == 1/(sqrt(a1)*a0) * 1/(sqrt(s)*(sqrt(s+b0)**3)) + a_ = poles[0][b] + k_ = constants/sqrt(poles[1][a])/poles[0][a] + result = k_*( + (2*a_**2*t+1)*t*exp(a_**2*t)*erfc(a_*sqrt(t)) - + 2/sqrt(pi)*a_*t**(S(3)/2)) + _debug(' rule 5.3.15') + elif ( + poles[0][n] == -1 and poles[0][m] == 1 and + poles[1][n] == -S.Half and poles[1][m] == 1): + # 1 / ( (a0*s+b0)* sqrt(a1*s+b1) ) + # == 1/(sqrt(a1)*a0) * 1 / ( (s+b0/a0)* sqrt(s+b1/a1) ) + a_ = poles[0][b]/poles[0][a] + b_ = poles[1][b]/poles[1][a] + k_ = constants/sqrt(poles[1][a])/poles[0][a] + result = k_*( + 1/sqrt(b_-a_)*exp(-a_*t)*erf(sqrt(b_-a_)*sqrt(t))) + _debug(' rule 5.3.23') + + elif len(poles) == 2 and len(zeros) == 1: + if ( + poles[0][n] == -1 and poles[0][m] == 1 and + poles[1][n] == -1 and poles[1][m] == S.Half and + zeros[0][n] == S.Half and zeros[0][m] == 1 and + zeros[0][b] == 0): + # sqrt(za0*s)/((a0*s+b0)*(a1*sqrt(s)+b1)) + # == sqrt(za0)/(a0*a1) * s/((s+b0/a0)*(sqrt(s)+b1/a1)) + a_sq = poles[1][b]/poles[1][a] + a_ = a_sq**2 + b_ = -poles[0][b]/poles[0][a] + k_ = sqrt(zeros[0][a])/poles[0][a]/poles[1][a]/(a_-b_)*constants + if a_sq.is_positive and b_.is_positive: + result = k_*( + a_*exp(a_*t)*erfc(sqrt(a_)*sqrt(t)) + + sqrt(a_)*sqrt(b_)*exp(b_*t)*erfc(sqrt(b_)*sqrt(t)) - + b_*exp(b_*t)) + _debug(' rule 5.3.6') + elif ( + poles[0][n] == -1 and poles[0][m] == 1 and + poles[0][b] == 0 and poles[1][n] == -1 and + poles[1][m] == S.Half and zeros[0][n] == 1 and + zeros[0][m] == S.Half): + # (az*sqrt(s)+bz)/(a0*s*(a1*sqrt(s)+b1)) + # == az/a0/a1 * (sqrt(z)+bz/az)/(s*(sqrt(s)+b1/a1)) + a_num = zeros[0][b]/zeros[0][a] + a_ = poles[1][b]/poles[1][a] + if a_+a_num == 0: + k_ = zeros[0][a]/poles[0][a]/poles[1][a]*constants + result = k_*( + 2*exp(a_**2*t)*erfc(a_*sqrt(t))-1) + _debug(' rule 5.3.17') + elif ( + poles[1][n] == -1 and poles[1][m] == 1 and + poles[1][b] == 0 and poles[0][n] == -2 and + poles[0][m] == S.Half and zeros[0][n] == 2 and + zeros[0][m] == S.Half): + # (az*sqrt(s)+bz)**2/(a1*s*(a0*sqrt(s)+b0)**2) + # == az**2/a1/a0**2 * (sqrt(z)+bz/az)**2/(s*(sqrt(s)+b0/a0)**2) + a_num = zeros[0][b]/zeros[0][a] + a_ = poles[0][b]/poles[0][a] + if a_+a_num == 0: + k_ = zeros[0][a]**2/poles[1][a]/poles[0][a]**2*constants + result = k_*( + 1 + 8*a_**2*t*exp(a_**2*t)*erfc(a_*sqrt(t)) - + 8/sqrt(pi)*a_*sqrt(t)) + _debug(' rule 5.3.18') + elif ( + poles[1][n] == -1 and poles[1][m] == 1 and + poles[1][b] == 0 and poles[0][n] == -3 and + poles[0][m] == S.Half and zeros[0][n] == 3 and + zeros[0][m] == S.Half): + # (az*sqrt(s)+bz)**3/(a1*s*(a0*sqrt(s)+b0)**3) + # == az**3/a1/a0**3 * (sqrt(z)+bz/az)**3/(s*(sqrt(s)+b0/a0)**3) + a_num = zeros[0][b]/zeros[0][a] + a_ = poles[0][b]/poles[0][a] + if a_+a_num == 0: + k_ = zeros[0][a]**3/poles[1][a]/poles[0][a]**3*constants + result = k_*( + 2*(8*a_**4*t**2+8*a_**2*t+1)*exp(a_**2*t) * + erfc(a_*sqrt(t))-8/sqrt(pi)*a_*sqrt(t)*(2*a_**2*t+1)-1) + _debug(' rule 5.3.19') + + elif len(poles) == 3 and len(zeros) == 0: + if ( + poles[0][n] == -1 and poles[0][b] == 0 and poles[0][m] == 1 and + poles[1][n] == -1 and poles[1][m] == 1 and + poles[2][n] == -S.Half and poles[2][m] == 1): + # 1/((a0*s)*(a1*s+b1)*sqrt(a2*s)) + # == 1/(a0*a1*sqrt(a2)) * 1/((s)*(s+b1/a1)*sqrt(s)) + a_ = -poles[1][b]/poles[1][a] + k_ = 1/poles[0][a]/poles[1][a]/sqrt(poles[2][a])*constants + if a_.is_positive: + result = k_ * ( + a_**(-S(3)/2) * exp(a_*t) * erf(sqrt(a_)*sqrt(t)) - + 2/a_/sqrt(pi)*sqrt(t)) + _debug(' rule 5.3.2') + elif ( + poles[0][n] == -1 and poles[0][m] == 1 and + poles[1][n] == -1 and poles[1][m] == S.Half and + poles[2][n] == -S.Half and poles[2][m] == 1 and + poles[2][b] == 0): + # 1/((a0*s+b0)*(a1*sqrt(s)+b1)*(sqrt(a2)*sqrt(s))) + # == 1/(a0*a1*sqrt(a2)) * 1/((s+b0/a0)*(sqrt(s)+b1/a1)*sqrt(s)) + a_sq = poles[1][b]/poles[1][a] + a_ = a_sq**2 + b_ = -poles[0][b]/poles[0][a] + k_ = ( + 1/poles[0][a]/poles[1][a]/sqrt(poles[2][a]) / + (sqrt(b_)*(a_-b_))) + if a_sq.is_positive and b_.is_positive: + result = k_ * ( + sqrt(b_)*exp(a_*t)*erfc(sqrt(a_)*sqrt(t)) + + sqrt(a_)*exp(b_*t)*erf(sqrt(b_)*sqrt(t)) - + sqrt(b_)*exp(b_*t)) + _debug(' rule 5.3.9') + + if result is None: + return None + else: + return Heaviside(t)*result, condition + + +@DEBUG_WRAP +def _inverse_laplace_early_prog_rules(F, s, t, plane): + """ + Helper function for the class InverseLaplaceTransform. + """ + prog_rules = [_inverse_laplace_irrational] + + for p_rule in prog_rules: + if (r := p_rule(F, s, t, plane)) is not None: + return r + return None + + +@DEBUG_WRAP +def _inverse_laplace_apply_prog_rules(F, s, t, plane): + """ + Helper function for the class InverseLaplaceTransform. + """ + prog_rules = [_inverse_laplace_time_shift, _inverse_laplace_freq_shift, + _inverse_laplace_time_diff, _inverse_laplace_diff, + _inverse_laplace_irrational] + + for p_rule in prog_rules: + if (r := p_rule(F, s, t, plane)) is not None: + return r + return None + + +@DEBUG_WRAP +def _inverse_laplace_expand(fn, s, t, plane): + """ + Helper function for the class InverseLaplaceTransform. + """ + if fn.is_Add: + return None + r = expand(fn, deep=False) + if r.is_Add: + return _inverse_laplace_transform( + r, s, t, plane, simplify=False, dorational=True) + r = expand_mul(fn) + if r.is_Add: + return _inverse_laplace_transform( + r, s, t, plane, simplify=False, dorational=True) + r = expand(fn) + if r.is_Add: + return _inverse_laplace_transform( + r, s, t, plane, simplify=False, dorational=True) + if fn.is_rational_function(s): + r = fn.apart(s).doit() + if r.is_Add: + return _inverse_laplace_transform( + r, s, t, plane, simplify=False, dorational=True) + return None + + +@DEBUG_WRAP +def _inverse_laplace_rational(fn, s, t, plane, *, simplify): + """ + Helper function for the class InverseLaplaceTransform. + """ + x_ = symbols('x_') + f = fn.apart(s) + terms = Add.make_args(f) + terms_t = [] + conditions = [S.true] + for term in terms: + [n, d] = term.as_numer_denom() + dc = d.as_poly(s).all_coeffs() + dc_lead = dc[0] + dc = [x/dc_lead for x in dc] + nc = [x/dc_lead for x in n.as_poly(s).all_coeffs()] + if len(dc) == 1: + N = len(nc)-1 + for c in enumerate(nc): + r = c[1]*DiracDelta(t, N-c[0]) + terms_t.append(r) + elif len(dc) == 2: + r = nc[0]*exp(-dc[1]*t) + terms_t.append(Heaviside(t)*r) + elif len(dc) == 3: + a = dc[1]/2 + b = (dc[2]-a**2).factor() + if len(nc) == 1: + nc = [S.Zero] + nc + l, m = tuple(nc) + if b == 0: + r = (m*t+l*(1-a*t))*exp(-a*t) + else: + hyp = False + if b.is_negative: + b = -b + hyp = True + b2 = list(roots(x_**2-b, x_).keys())[0] + bs = sqrt(b).simplify() + if hyp: + r = ( + l*exp(-a*t)*cosh(b2*t) + (m-a*l) / + bs*exp(-a*t)*sinh(bs*t)) + else: + r = l*exp(-a*t)*cos(b2*t) + (m-a*l)/bs*exp(-a*t)*sin(bs*t) + terms_t.append(Heaviside(t)*r) + else: + ft, cond = _inverse_laplace_transform( + term, s, t, plane, simplify=simplify, dorational=False) + terms_t.append(ft) + conditions.append(cond) + + result = Add(*terms_t) + if simplify: + result = result.simplify(doit=False) + return result, And(*conditions) + + +@DEBUG_WRAP +def _inverse_laplace_transform(fn, s_, t_, plane, *, simplify, dorational): + """ + Front-end function of the inverse Laplace transform. It tries to apply all + known rules recursively. If everything else fails, it tries to integrate. + """ + terms = Add.make_args(fn) + terms_t = [] + conditions = [] + + for term in terms: + if term.has(exp): + # Simplify expressions with exp() such that time-shifted + # expressions have negative exponents in the numerator instead of + # positive exponents in the numerator and denominator; this is a + # (necessary) trick. It will, for example, convert + # (s**2*exp(2*s) + 4*exp(s) - 4)*exp(-2*s)/(s*(s**2 + 1)) into + # (s**2 + 4*exp(-s) - 4*exp(-2*s))/(s*(s**2 + 1)) + term = term.subs(s_, -s_).together().subs(s_, -s_) + k, f = term.as_independent(s_, as_Add=False) + if ( + dorational and term.is_rational_function(s_) and + (r := _inverse_laplace_rational( + f, s_, t_, plane, simplify=simplify)) + is not None or + (r := _inverse_laplace_apply_simple_rules(f, s_, t_)) + is not None or + (r := _inverse_laplace_early_prog_rules(f, s_, t_, plane)) + is not None or + (r := _inverse_laplace_expand(f, s_, t_, plane)) + is not None or + (r := _inverse_laplace_apply_prog_rules(f, s_, t_, plane)) + is not None): + pass + elif any(undef.has(s_) for undef in f.atoms(AppliedUndef)): + # If there are undefined functions f(t) then integration is + # unlikely to do anything useful so we skip it and given an + # unevaluated LaplaceTransform. + r = (InverseLaplaceTransform(f, s_, t_, plane), S.true) + elif ( + r := _inverse_laplace_transform_integration( + f, s_, t_, plane, simplify=simplify)) is not None: + pass + else: + r = (InverseLaplaceTransform(f, s_, t_, plane), S.true) + (ri_, ci_) = r + terms_t.append(k*ri_) + conditions.append(ci_) + + result = Add(*terms_t) + if simplify: + result = result.simplify(doit=False) + condition = And(*conditions) + + return result, condition + + +class InverseLaplaceTransform(IntegralTransform): + """ + Class representing unevaluated inverse Laplace transforms. + + For usage of this class, see the :class:`IntegralTransform` docstring. + + For how to compute inverse Laplace transforms, see the + :func:`inverse_laplace_transform` docstring. + """ + + _name = 'Inverse Laplace' + _none_sentinel = Dummy('None') + _c = Dummy('c') + + def __new__(cls, F, s, x, plane, **opts): + if plane is None: + plane = InverseLaplaceTransform._none_sentinel + return IntegralTransform.__new__(cls, F, s, x, plane, **opts) + + @property + def fundamental_plane(self): + plane = self.args[3] + if plane is InverseLaplaceTransform._none_sentinel: + plane = None + return plane + + def _compute_transform(self, F, s, t, **hints): + return _inverse_laplace_transform_integration( + F, s, t, self.fundamental_plane, **hints) + + def _as_integral(self, F, s, t): + c = self.__class__._c + return ( + Integral(exp(s*t)*F, (s, c - S.ImaginaryUnit*S.Infinity, + c + S.ImaginaryUnit*S.Infinity)) / + (2*S.Pi*S.ImaginaryUnit)) + + def doit(self, **hints): + """ + Try to evaluate the transform in closed form. + + Explanation + =========== + + Standard hints are the following: + - ``noconds``: if True, do not return convergence conditions. The + default setting is `True`. + - ``simplify``: if True, it simplifies the final result. The + default setting is `False`. + """ + _noconds = hints.get('noconds', True) + _simplify = hints.get('simplify', False) + + debugf('[ILT doit] (%s, %s, %s)', (self.function, + self.function_variable, + self.transform_variable)) + + s_ = self.function_variable + t_ = self.transform_variable + fn = self.function + plane = self.fundamental_plane + + r = _inverse_laplace_transform( + fn, s_, t_, plane, simplify=_simplify, dorational=True) + + if _noconds: + return r[0] + else: + return r + + +def inverse_laplace_transform(F, s, t, plane=None, **hints): + r""" + Compute the inverse Laplace transform of `F(s)`, defined as + + .. math :: + f(t) = \frac{1}{2\pi i} \int_{c-i\infty}^{c+i\infty} e^{st} + F(s) \mathrm{d}s, + + for `c` so large that `F(s)` has no singularites in the + half-plane `\operatorname{Re}(s) > c-\epsilon`. + + Explanation + =========== + + The plane can be specified by + argument ``plane``, but will be inferred if passed as None. + + Under certain regularity conditions, this recovers `f(t)` from its + Laplace Transform `F(s)`, for non-negative `t`, and vice + versa. + + If the integral cannot be computed in closed form, this function returns + an unevaluated :class:`InverseLaplaceTransform` object. + + Note that this function will always assume `t` to be real, + regardless of the SymPy assumption on `t`. + + For a description of possible hints, refer to the docstring of + :func:`sympy.integrals.transforms.IntegralTransform.doit`. + + Examples + ======== + + >>> from sympy import inverse_laplace_transform, exp, Symbol + >>> from sympy.abc import s, t + >>> a = Symbol('a', positive=True) + >>> inverse_laplace_transform(exp(-a*s)/s, s, t) + Heaviside(-a + t) + + See Also + ======== + + laplace_transform + hankel_transform, inverse_hankel_transform + """ + _noconds = hints.get('noconds', True) + _simplify = hints.get('simplify', False) + + if isinstance(F, MatrixBase) and hasattr(F, 'applyfunc'): + return F.applyfunc( + lambda Fij: inverse_laplace_transform(Fij, s, t, plane, **hints)) + + r, c = InverseLaplaceTransform(F, s, t, plane).doit( + noconds=False, simplify=_simplify) + + if _noconds: + return r + else: + return r, c + + +def _fast_inverse_laplace(e, s, t): + """Fast inverse Laplace transform of rational function including RootSum""" + a, b, n = symbols('a, b, n', cls=Wild, exclude=[s]) + + def _ilt(e): + if not e.has(s): + return e + elif e.is_Add: + return _ilt_add(e) + elif e.is_Mul: + return _ilt_mul(e) + elif e.is_Pow: + return _ilt_pow(e) + elif isinstance(e, RootSum): + return _ilt_rootsum(e) + else: + raise NotImplementedError + + def _ilt_add(e): + return e.func(*map(_ilt, e.args)) + + def _ilt_mul(e): + coeff, expr = e.as_independent(s) + if expr.is_Mul: + raise NotImplementedError + return coeff * _ilt(expr) + + def _ilt_pow(e): + match = e.match((a*s + b)**n) + if match is not None: + nm, am, bm = match[n], match[a], match[b] + if nm.is_Integer and nm < 0: + return t**(-nm-1)*exp(-(bm/am)*t)/(am**-nm*gamma(-nm)) + if nm == 1: + return exp(-(bm/am)*t) / am + raise NotImplementedError + + def _ilt_rootsum(e): + expr = e.fun.expr + [variable] = e.fun.variables + return RootSum(e.poly, Lambda(variable, together(_ilt(expr)))) + + return _ilt(e) diff --git a/.venv/lib/python3.13/site-packages/sympy/integrals/manualintegrate.py b/.venv/lib/python3.13/site-packages/sympy/integrals/manualintegrate.py new file mode 100644 index 0000000000000000000000000000000000000000..2908fb33003ba9c22da47e550edcfea2b41a26a0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/integrals/manualintegrate.py @@ -0,0 +1,2174 @@ +"""Integration method that emulates by-hand techniques. + +This module also provides functionality to get the steps used to evaluate a +particular integral, in the ``integral_steps`` function. This will return +nested ``Rule`` s representing the integration rules used. + +Each ``Rule`` class represents a (maybe parametrized) integration rule, e.g. +``SinRule`` for integrating ``sin(x)`` and ``ReciprocalSqrtQuadraticRule`` +for integrating ``1/sqrt(a+b*x+c*x**2)``. The ``eval`` method returns the +integration result. + +The ``manualintegrate`` function computes the integral by calling ``eval`` +on the rule returned by ``integral_steps``. + +The integrator can be extended with new heuristics and evaluation +techniques. To do so, extend the ``Rule`` class, implement ``eval`` method, +then write a function that accepts an ``IntegralInfo`` object and returns +either a ``Rule`` instance or ``None``. If the new technique requires a new +match, add the key and call to the antiderivative function to integral_steps. +To enable simple substitutions, add the match to find_substitutions. + +""" + +from __future__ import annotations +from typing import NamedTuple, Type, Callable, Sequence +from abc import ABC, abstractmethod +from dataclasses import dataclass +from collections import defaultdict +from collections.abc import Mapping + +from sympy.core.add import Add +from sympy.core.cache import cacheit +from sympy.core.containers import Dict +from sympy.core.expr import Expr +from sympy.core.function import Derivative +from sympy.core.logic import fuzzy_not +from sympy.core.mul import Mul +from sympy.core.numbers import Integer, Number, E +from sympy.core.power import Pow +from sympy.core.relational import Eq, Ne, Boolean +from sympy.core.singleton import S +from sympy.core.symbol import Dummy, Symbol, Wild +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.hyperbolic import (HyperbolicFunction, csch, + cosh, coth, sech, sinh, tanh, asinh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (TrigonometricFunction, + cos, sin, tan, cot, csc, sec, acos, asin, atan, acot, acsc, asec) +from sympy.functions.special.delta_functions import Heaviside, DiracDelta +from sympy.functions.special.error_functions import (erf, erfi, fresnelc, + fresnels, Ci, Chi, Si, Shi, Ei, li) +from sympy.functions.special.gamma_functions import uppergamma +from sympy.functions.special.elliptic_integrals import elliptic_e, elliptic_f +from sympy.functions.special.polynomials import (chebyshevt, chebyshevu, + legendre, hermite, laguerre, assoc_laguerre, gegenbauer, jacobi, + OrthogonalPolynomial) +from sympy.functions.special.zeta_functions import polylog +from .integrals import Integral +from sympy.logic.boolalg import And +from sympy.ntheory.factor_ import primefactors +from sympy.polys.polytools import degree, lcm_list, gcd_list, Poly +from sympy.simplify.radsimp import fraction +from sympy.simplify.simplify import simplify +from sympy.solvers.solvers import solve +from sympy.strategies.core import switch, do_one, null_safe, condition +from sympy.utilities.iterables import iterable +from sympy.utilities.misc import debug + + +@dataclass +class Rule(ABC): + integrand: Expr + variable: Symbol + + @abstractmethod + def eval(self) -> Expr: + pass + + @abstractmethod + def contains_dont_know(self) -> bool: + pass + + +@dataclass +class AtomicRule(Rule, ABC): + """A simple rule that does not depend on other rules""" + def contains_dont_know(self) -> bool: + return False + + +@dataclass +class ConstantRule(AtomicRule): + """integrate(a, x) -> a*x""" + def eval(self) -> Expr: + return self.integrand * self.variable + + +@dataclass +class ConstantTimesRule(Rule): + """integrate(a*f(x), x) -> a*integrate(f(x), x)""" + constant: Expr + other: Expr + substep: Rule + + def eval(self) -> Expr: + return self.constant * self.substep.eval() + + def contains_dont_know(self) -> bool: + return self.substep.contains_dont_know() + + +@dataclass +class PowerRule(AtomicRule): + """integrate(x**a, x)""" + base: Expr + exp: Expr + + def eval(self) -> Expr: + return Piecewise( + ((self.base**(self.exp + 1))/(self.exp + 1), Ne(self.exp, -1)), + (log(self.base), True), + ) + + +@dataclass +class NestedPowRule(AtomicRule): + """integrate((x**a)**b, x)""" + base: Expr + exp: Expr + + def eval(self) -> Expr: + m = self.base * self.integrand + return Piecewise((m / (self.exp + 1), Ne(self.exp, -1)), + (m * log(self.base), True)) + + +@dataclass +class AddRule(Rule): + """integrate(f(x) + g(x), x) -> integrate(f(x), x) + integrate(g(x), x)""" + substeps: list[Rule] + + def eval(self) -> Expr: + return Add(*(substep.eval() for substep in self.substeps)) + + def contains_dont_know(self) -> bool: + return any(substep.contains_dont_know() for substep in self.substeps) + + +@dataclass +class URule(Rule): + """integrate(f(g(x))*g'(x), x) -> integrate(f(u), u), u = g(x)""" + u_var: Symbol + u_func: Expr + substep: Rule + + def eval(self) -> Expr: + result = self.substep.eval() + if self.u_func.is_Pow: + base, exp_ = self.u_func.as_base_exp() + if exp_ == -1: + # avoid needless -log(1/x) from substitution + result = result.subs(log(self.u_var), -log(base)) + return result.subs(self.u_var, self.u_func) + + def contains_dont_know(self) -> bool: + return self.substep.contains_dont_know() + + +@dataclass +class PartsRule(Rule): + """integrate(u(x)*v'(x), x) -> u(x)*v(x) - integrate(u'(x)*v(x), x)""" + u: Symbol + dv: Expr + v_step: Rule + second_step: Rule | None # None when is a substep of CyclicPartsRule + + def eval(self) -> Expr: + assert self.second_step is not None + v = self.v_step.eval() + return self.u * v - self.second_step.eval() + + def contains_dont_know(self) -> bool: + return self.v_step.contains_dont_know() or ( + self.second_step is not None and self.second_step.contains_dont_know()) + + +@dataclass +class CyclicPartsRule(Rule): + """Apply PartsRule multiple times to integrate exp(x)*sin(x)""" + parts_rules: list[PartsRule] + coefficient: Expr + + def eval(self) -> Expr: + result = [] + sign = 1 + for rule in self.parts_rules: + result.append(sign * rule.u * rule.v_step.eval()) + sign *= -1 + return Add(*result) / (1 - self.coefficient) + + def contains_dont_know(self) -> bool: + return any(substep.contains_dont_know() for substep in self.parts_rules) + + +@dataclass +class TrigRule(AtomicRule, ABC): + pass + + +@dataclass +class SinRule(TrigRule): + """integrate(sin(x), x) -> -cos(x)""" + def eval(self) -> Expr: + return -cos(self.variable) + + +@dataclass +class CosRule(TrigRule): + """integrate(cos(x), x) -> sin(x)""" + def eval(self) -> Expr: + return sin(self.variable) + + +@dataclass +class SecTanRule(TrigRule): + """integrate(sec(x)*tan(x), x) -> sec(x)""" + def eval(self) -> Expr: + return sec(self.variable) + + +@dataclass +class CscCotRule(TrigRule): + """integrate(csc(x)*cot(x), x) -> -csc(x)""" + def eval(self) -> Expr: + return -csc(self.variable) + + +@dataclass +class Sec2Rule(TrigRule): + """integrate(sec(x)**2, x) -> tan(x)""" + def eval(self) -> Expr: + return tan(self.variable) + + +@dataclass +class Csc2Rule(TrigRule): + """integrate(csc(x)**2, x) -> -cot(x)""" + def eval(self) -> Expr: + return -cot(self.variable) + + +@dataclass +class HyperbolicRule(AtomicRule, ABC): + pass + + +@dataclass +class SinhRule(HyperbolicRule): + """integrate(sinh(x), x) -> cosh(x)""" + def eval(self) -> Expr: + return cosh(self.variable) + + +@dataclass +class CoshRule(HyperbolicRule): + """integrate(cosh(x), x) -> sinh(x)""" + def eval(self): + return sinh(self.variable) + + +@dataclass +class ExpRule(AtomicRule): + """integrate(a**x, x) -> a**x/ln(a)""" + base: Expr + exp: Expr + + def eval(self) -> Expr: + return self.integrand / log(self.base) + + +@dataclass +class ReciprocalRule(AtomicRule): + """integrate(1/x, x) -> ln(x)""" + base: Expr + + def eval(self) -> Expr: + return log(self.base) + + +@dataclass +class ArcsinRule(AtomicRule): + """integrate(1/sqrt(1-x**2), x) -> asin(x)""" + def eval(self) -> Expr: + return asin(self.variable) + + +@dataclass +class ArcsinhRule(AtomicRule): + """integrate(1/sqrt(1+x**2), x) -> asin(x)""" + def eval(self) -> Expr: + return asinh(self.variable) + + +@dataclass +class ReciprocalSqrtQuadraticRule(AtomicRule): + """integrate(1/sqrt(a+b*x+c*x**2), x) -> log(2*sqrt(c)*sqrt(a+b*x+c*x**2)+b+2*c*x)/sqrt(c)""" + a: Expr + b: Expr + c: Expr + + def eval(self) -> Expr: + a, b, c, x = self.a, self.b, self.c, self.variable + return log(2*sqrt(c)*sqrt(a+b*x+c*x**2)+b+2*c*x)/sqrt(c) + + +@dataclass +class SqrtQuadraticDenomRule(AtomicRule): + """integrate(poly(x)/sqrt(a+b*x+c*x**2), x)""" + a: Expr + b: Expr + c: Expr + coeffs: list[Expr] + + def eval(self) -> Expr: + a, b, c, coeffs, x = self.a, self.b, self.c, self.coeffs.copy(), self.variable + # Integrate poly/sqrt(a+b*x+c*x**2) using recursion. + # coeffs are coefficients of the polynomial. + # Let I_n = x**n/sqrt(a+b*x+c*x**2), then + # I_n = A * x**(n-1)*sqrt(a+b*x+c*x**2) - B * I_{n-1} - C * I_{n-2} + # where A = 1/(n*c), B = (2*n-1)*b/(2*n*c), C = (n-1)*a/(n*c) + # See https://github.com/sympy/sympy/pull/23608 for proof. + result_coeffs = [] + coeffs = coeffs.copy() + for i in range(len(coeffs)-2): + n = len(coeffs)-1-i + coeff = coeffs[i]/(c*n) + result_coeffs.append(coeff) + coeffs[i+1] -= (2*n-1)*b/2*coeff + coeffs[i+2] -= (n-1)*a*coeff + d, e = coeffs[-1], coeffs[-2] + s = sqrt(a+b*x+c*x**2) + constant = d-b*e/(2*c) + if constant == 0: + I0 = 0 + else: + step = inverse_trig_rule(IntegralInfo(1/s, x), degenerate=False) + I0 = constant*step.eval() + return Add(*(result_coeffs[i]*x**(len(coeffs)-2-i) + for i in range(len(result_coeffs))), e/c)*s + I0 + + +@dataclass +class SqrtQuadraticRule(AtomicRule): + """integrate(sqrt(a+b*x+c*x**2), x)""" + a: Expr + b: Expr + c: Expr + + def eval(self) -> Expr: + step = sqrt_quadratic_rule(IntegralInfo(self.integrand, self.variable), degenerate=False) + return step.eval() + + +@dataclass +class AlternativeRule(Rule): + """Multiple ways to do integration.""" + alternatives: list[Rule] + + def eval(self) -> Expr: + return self.alternatives[0].eval() + + def contains_dont_know(self) -> bool: + return any(substep.contains_dont_know() for substep in self.alternatives) + + +@dataclass +class DontKnowRule(Rule): + """Leave the integral as is.""" + def eval(self) -> Expr: + return Integral(self.integrand, self.variable) + + def contains_dont_know(self) -> bool: + return True + + +@dataclass +class DerivativeRule(AtomicRule): + """integrate(f'(x), x) -> f(x)""" + def eval(self) -> Expr: + assert isinstance(self.integrand, Derivative) + variable_count = list(self.integrand.variable_count) + for i, (var, count) in enumerate(variable_count): + if var == self.variable: + variable_count[i] = (var, count - 1) + break + return Derivative(self.integrand.expr, *variable_count) + + +@dataclass +class RewriteRule(Rule): + """Rewrite integrand to another form that is easier to handle.""" + rewritten: Expr + substep: Rule + + def eval(self) -> Expr: + return self.substep.eval() + + def contains_dont_know(self) -> bool: + return self.substep.contains_dont_know() + + +@dataclass +class CompleteSquareRule(RewriteRule): + """Rewrite a+b*x+c*x**2 to a-b**2/(4*c) + c*(x+b/(2*c))**2""" + pass + + +@dataclass +class PiecewiseRule(Rule): + subfunctions: Sequence[tuple[Rule, bool | Boolean]] + + def eval(self) -> Expr: + return Piecewise(*[(substep.eval(), cond) + for substep, cond in self.subfunctions]) + + def contains_dont_know(self) -> bool: + return any(substep.contains_dont_know() for substep, _ in self.subfunctions) + + +@dataclass +class HeavisideRule(Rule): + harg: Expr + ibnd: Expr + substep: Rule + + def eval(self) -> Expr: + # If we are integrating over x and the integrand has the form + # Heaviside(m*x+b)*g(x) == Heaviside(harg)*g(symbol) + # then there needs to be continuity at -b/m == ibnd, + # so we subtract the appropriate term. + result = self.substep.eval() + return Heaviside(self.harg) * (result - result.subs(self.variable, self.ibnd)) + + def contains_dont_know(self) -> bool: + return self.substep.contains_dont_know() + + +@dataclass +class DiracDeltaRule(AtomicRule): + n: Expr + a: Expr + b: Expr + + def eval(self) -> Expr: + n, a, b, x = self.n, self.a, self.b, self.variable + if n == 0: + return Heaviside(a+b*x)/b + return DiracDelta(a+b*x, n-1)/b + + +@dataclass +class TrigSubstitutionRule(Rule): + theta: Expr + func: Expr + rewritten: Expr + substep: Rule + restriction: bool | Boolean + + def eval(self) -> Expr: + theta, func, x = self.theta, self.func, self.variable + func = func.subs(sec(theta), 1/cos(theta)) + func = func.subs(csc(theta), 1/sin(theta)) + func = func.subs(cot(theta), 1/tan(theta)) + + trig_function = list(func.find(TrigonometricFunction)) + assert len(trig_function) == 1 + trig_function = trig_function[0] + relation = solve(x - func, trig_function) + assert len(relation) == 1 + numer, denom = fraction(relation[0]) + + if isinstance(trig_function, sin): + opposite = numer + hypotenuse = denom + adjacent = sqrt(denom**2 - numer**2) + inverse = asin(relation[0]) + elif isinstance(trig_function, cos): + adjacent = numer + hypotenuse = denom + opposite = sqrt(denom**2 - numer**2) + inverse = acos(relation[0]) + else: # tan + opposite = numer + adjacent = denom + hypotenuse = sqrt(denom**2 + numer**2) + inverse = atan(relation[0]) + + substitution = [ + (sin(theta), opposite/hypotenuse), + (cos(theta), adjacent/hypotenuse), + (tan(theta), opposite/adjacent), + (theta, inverse) + ] + return Piecewise( + (self.substep.eval().subs(substitution).trigsimp(), self.restriction) # type: ignore + ) + + def contains_dont_know(self) -> bool: + return self.substep.contains_dont_know() + + +@dataclass +class ArctanRule(AtomicRule): + """integrate(a/(b*x**2+c), x) -> a/b / sqrt(c/b) * atan(x/sqrt(c/b))""" + a: Expr + b: Expr + c: Expr + + def eval(self) -> Expr: + a, b, c, x = self.a, self.b, self.c, self.variable + return a/b / sqrt(c/b) * atan(x/sqrt(c/b)) + + +@dataclass +class OrthogonalPolyRule(AtomicRule, ABC): + n: Expr + + +@dataclass +class JacobiRule(OrthogonalPolyRule): + a: Expr + b: Expr + + def eval(self) -> Expr: + n, a, b, x = self.n, self.a, self.b, self.variable + return Piecewise( + (2*jacobi(n + 1, a - 1, b - 1, x)/(n + a + b), Ne(n + a + b, 0)), + (x, Eq(n, 0)), + ((a + b + 2)*x**2/4 + (a - b)*x/2, Eq(n, 1))) + + +@dataclass +class GegenbauerRule(OrthogonalPolyRule): + a: Expr + + def eval(self) -> Expr: + n, a, x = self.n, self.a, self.variable + return Piecewise( + (gegenbauer(n + 1, a - 1, x)/(2*(a - 1)), Ne(a, 1)), + (chebyshevt(n + 1, x)/(n + 1), Ne(n, -1)), + (S.Zero, True)) + + +@dataclass +class ChebyshevTRule(OrthogonalPolyRule): + def eval(self) -> Expr: + n, x = self.n, self.variable + return Piecewise( + ((chebyshevt(n + 1, x)/(n + 1) - + chebyshevt(n - 1, x)/(n - 1))/2, Ne(Abs(n), 1)), + (x**2/2, True)) + + +@dataclass +class ChebyshevURule(OrthogonalPolyRule): + def eval(self) -> Expr: + n, x = self.n, self.variable + return Piecewise( + (chebyshevt(n + 1, x)/(n + 1), Ne(n, -1)), + (S.Zero, True)) + + +@dataclass +class LegendreRule(OrthogonalPolyRule): + def eval(self) -> Expr: + n, x = self.n, self.variable + return(legendre(n + 1, x) - legendre(n - 1, x))/(2*n + 1) + + +@dataclass +class HermiteRule(OrthogonalPolyRule): + def eval(self) -> Expr: + n, x = self.n, self.variable + return hermite(n + 1, x)/(2*(n + 1)) + + +@dataclass +class LaguerreRule(OrthogonalPolyRule): + def eval(self) -> Expr: + n, x = self.n, self.variable + return laguerre(n, x) - laguerre(n + 1, x) + + +@dataclass +class AssocLaguerreRule(OrthogonalPolyRule): + a: Expr + + def eval(self) -> Expr: + return -assoc_laguerre(self.n + 1, self.a - 1, self.variable) + + +@dataclass +class IRule(AtomicRule, ABC): + a: Expr + b: Expr + + +@dataclass +class CiRule(IRule): + def eval(self) -> Expr: + a, b, x = self.a, self.b, self.variable + return cos(b)*Ci(a*x) - sin(b)*Si(a*x) + + +@dataclass +class ChiRule(IRule): + def eval(self) -> Expr: + a, b, x = self.a, self.b, self.variable + return cosh(b)*Chi(a*x) + sinh(b)*Shi(a*x) + + +@dataclass +class EiRule(IRule): + def eval(self) -> Expr: + a, b, x = self.a, self.b, self.variable + return exp(b)*Ei(a*x) + + +@dataclass +class SiRule(IRule): + def eval(self) -> Expr: + a, b, x = self.a, self.b, self.variable + return sin(b)*Ci(a*x) + cos(b)*Si(a*x) + + +@dataclass +class ShiRule(IRule): + def eval(self) -> Expr: + a, b, x = self.a, self.b, self.variable + return sinh(b)*Chi(a*x) + cosh(b)*Shi(a*x) + + +@dataclass +class LiRule(IRule): + def eval(self) -> Expr: + a, b, x = self.a, self.b, self.variable + return li(a*x + b)/a + + +@dataclass +class ErfRule(AtomicRule): + a: Expr + b: Expr + c: Expr + + def eval(self) -> Expr: + a, b, c, x = self.a, self.b, self.c, self.variable + if a.is_extended_real: + return Piecewise( + (sqrt(S.Pi)/sqrt(-a)/2 * exp(c - b**2/(4*a)) * + erf((-2*a*x - b)/(2*sqrt(-a))), a < 0), + (sqrt(S.Pi)/sqrt(a)/2 * exp(c - b**2/(4*a)) * + erfi((2*a*x + b)/(2*sqrt(a))), True)) + return sqrt(S.Pi)/sqrt(a)/2 * exp(c - b**2/(4*a)) * \ + erfi((2*a*x + b)/(2*sqrt(a))) + + +@dataclass +class FresnelCRule(AtomicRule): + a: Expr + b: Expr + c: Expr + + def eval(self) -> Expr: + a, b, c, x = self.a, self.b, self.c, self.variable + return sqrt(S.Pi)/sqrt(2*a) * ( + cos(b**2/(4*a) - c)*fresnelc((2*a*x + b)/sqrt(2*a*S.Pi)) + + sin(b**2/(4*a) - c)*fresnels((2*a*x + b)/sqrt(2*a*S.Pi))) + + +@dataclass +class FresnelSRule(AtomicRule): + a: Expr + b: Expr + c: Expr + + def eval(self) -> Expr: + a, b, c, x = self.a, self.b, self.c, self.variable + return sqrt(S.Pi)/sqrt(2*a) * ( + cos(b**2/(4*a) - c)*fresnels((2*a*x + b)/sqrt(2*a*S.Pi)) - + sin(b**2/(4*a) - c)*fresnelc((2*a*x + b)/sqrt(2*a*S.Pi))) + + +@dataclass +class PolylogRule(AtomicRule): + a: Expr + b: Expr + + def eval(self) -> Expr: + return polylog(self.b + 1, self.a * self.variable) + + +@dataclass +class UpperGammaRule(AtomicRule): + a: Expr + e: Expr + + def eval(self) -> Expr: + a, e, x = self.a, self.e, self.variable + return x**e * (-a*x)**(-e) * uppergamma(e + 1, -a*x)/a + + +@dataclass +class EllipticFRule(AtomicRule): + a: Expr + d: Expr + + def eval(self) -> Expr: + return elliptic_f(self.variable, self.d/self.a)/sqrt(self.a) + + +@dataclass +class EllipticERule(AtomicRule): + a: Expr + d: Expr + + def eval(self) -> Expr: + return elliptic_e(self.variable, self.d/self.a)*sqrt(self.a) + + +class IntegralInfo(NamedTuple): + integrand: Expr + symbol: Symbol + + +def manual_diff(f, symbol): + """Derivative of f in form expected by find_substitutions + + SymPy's derivatives for some trig functions (like cot) are not in a form + that works well with finding substitutions; this replaces the + derivatives for those particular forms with something that works better. + + """ + if f.args: + arg = f.args[0] + if isinstance(f, tan): + return arg.diff(symbol) * sec(arg)**2 + elif isinstance(f, cot): + return -arg.diff(symbol) * csc(arg)**2 + elif isinstance(f, sec): + return arg.diff(symbol) * sec(arg) * tan(arg) + elif isinstance(f, csc): + return -arg.diff(symbol) * csc(arg) * cot(arg) + elif isinstance(f, Add): + return sum(manual_diff(arg, symbol) for arg in f.args) + elif isinstance(f, Mul): + if len(f.args) == 2 and isinstance(f.args[0], Number): + return f.args[0] * manual_diff(f.args[1], symbol) + return f.diff(symbol) + +def manual_subs(expr, *args): + """ + A wrapper for `expr.subs(*args)` with additional logic for substitution + of invertible functions. + """ + if len(args) == 1: + sequence = args[0] + if isinstance(sequence, (Dict, Mapping)): + sequence = sequence.items() + elif not iterable(sequence): + raise ValueError("Expected an iterable of (old, new) pairs") + elif len(args) == 2: + sequence = [args] + else: + raise ValueError("subs accepts either 1 or 2 arguments") + + new_subs = [] + for old, new in sequence: + if isinstance(old, log): + # If log(x) = y, then exp(a*log(x)) = exp(a*y) + # that is, x**a = exp(a*y). Replace nontrivial powers of x + # before subs turns them into `exp(y)**a`, but + # do not replace x itself yet, to avoid `log(exp(y))`. + x0 = old.args[0] + expr = expr.replace(lambda x: x.is_Pow and x.base == x0, + lambda x: exp(x.exp*new)) + new_subs.append((x0, exp(new))) + + return expr.subs(list(sequence) + new_subs) + +# Method based on that on SIN, described in "Symbolic Integration: The +# Stormy Decade" + +inverse_trig_functions = (atan, asin, acos, acot, acsc, asec) + + +def find_substitutions(integrand, symbol, u_var): + results = [] + + def test_subterm(u, u_diff): + if u_diff == 0: + return False + substituted = integrand / u_diff + debug("substituted: {}, u: {}, u_var: {}".format(substituted, u, u_var)) + substituted = manual_subs(substituted, u, u_var).cancel() + + if substituted.has_free(symbol): + return False + # avoid increasing the degree of a rational function + if integrand.is_rational_function(symbol) and substituted.is_rational_function(u_var): + deg_before = max(degree(t, symbol) for t in integrand.as_numer_denom()) + deg_after = max(degree(t, u_var) for t in substituted.as_numer_denom()) + if deg_after > deg_before: + return False + return substituted.as_independent(u_var, as_Add=False) + + def exp_subterms(term: Expr): + linear_coeffs = [] + terms = [] + n = Wild('n', properties=[lambda n: n.is_Integer]) + for exp_ in term.find(exp): + arg = exp_.args[0] + if symbol not in arg.free_symbols: + continue + match = arg.match(n*symbol) + if match: + linear_coeffs.append(match[n]) + else: + terms.append(exp_) + if linear_coeffs: + terms.append(exp(gcd_list(linear_coeffs)*symbol)) + return terms + + def possible_subterms(term): + if isinstance(term, (TrigonometricFunction, HyperbolicFunction, + *inverse_trig_functions, + exp, log, Heaviside)): + return [term.args[0]] + elif isinstance(term, (chebyshevt, chebyshevu, + legendre, hermite, laguerre)): + return [term.args[1]] + elif isinstance(term, (gegenbauer, assoc_laguerre)): + return [term.args[2]] + elif isinstance(term, jacobi): + return [term.args[3]] + elif isinstance(term, Mul): + r = [] + for u in term.args: + r.append(u) + r.extend(possible_subterms(u)) + return r + elif isinstance(term, Pow): + r = [arg for arg in term.args if arg.has(symbol)] + if term.exp.is_Integer: + r.extend([term.base**d for d in primefactors(term.exp) + if 1 < d < abs(term.args[1])]) + if term.base.is_Add: + r.extend([t for t in possible_subterms(term.base) + if t.is_Pow]) + return r + elif isinstance(term, Add): + r = [] + for arg in term.args: + r.append(arg) + r.extend(possible_subterms(arg)) + return r + return [] + + for u in list(dict.fromkeys(possible_subterms(integrand) + exp_subterms(integrand))): + if u == symbol: + continue + u_diff = manual_diff(u, symbol) + new_integrand = test_subterm(u, u_diff) + if new_integrand is not False: + constant, new_integrand = new_integrand + if new_integrand == integrand.subs(symbol, u_var): + continue + substitution = (u, constant, new_integrand) + if substitution not in results: + results.append(substitution) + + return results + +def rewriter(condition, rewrite): + """Strategy that rewrites an integrand.""" + def _rewriter(integral): + integrand, symbol = integral + debug("Integral: {} is rewritten with {} on symbol: {}".format(integrand, rewrite, symbol)) + if condition(*integral): + rewritten = rewrite(*integral) + if rewritten != integrand: + substep = integral_steps(rewritten, symbol) + if not isinstance(substep, DontKnowRule) and substep: + return RewriteRule(integrand, symbol, rewritten, substep) + return _rewriter + +def proxy_rewriter(condition, rewrite): + """Strategy that rewrites an integrand based on some other criteria.""" + def _proxy_rewriter(criteria): + criteria, integral = criteria + integrand, symbol = integral + debug("Integral: {} is rewritten with {} on symbol: {} and criteria: {}".format(integrand, rewrite, symbol, criteria)) + args = criteria + list(integral) + if condition(*args): + rewritten = rewrite(*args) + if rewritten != integrand: + return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol)) + return _proxy_rewriter + +def multiplexer(conditions): + """Apply the rule that matches the condition, else None""" + def multiplexer_rl(expr): + for key, rule in conditions.items(): + if key(expr): + return rule(expr) + return multiplexer_rl + +def alternatives(*rules): + """Strategy that makes an AlternativeRule out of multiple possible results.""" + def _alternatives(integral): + alts = [] + count = 0 + debug("List of Alternative Rules") + for rule in rules: + count = count + 1 + debug("Rule {}: {}".format(count, rule)) + + result = rule(integral) + if (result and not isinstance(result, DontKnowRule) and + result != integral and result not in alts): + alts.append(result) + if len(alts) == 1: + return alts[0] + elif alts: + doable = [rule for rule in alts if not rule.contains_dont_know()] + if doable: + return AlternativeRule(*integral, doable) + else: + return AlternativeRule(*integral, alts) + return _alternatives + +def constant_rule(integral): + return ConstantRule(*integral) + +def power_rule(integral): + integrand, symbol = integral + base, expt = integrand.as_base_exp() + + if symbol not in expt.free_symbols and isinstance(base, Symbol): + if simplify(expt + 1) == 0: + return ReciprocalRule(integrand, symbol, base) + return PowerRule(integrand, symbol, base, expt) + elif symbol not in base.free_symbols and isinstance(expt, Symbol): + rule = ExpRule(integrand, symbol, base, expt) + + if fuzzy_not(log(base).is_zero): + return rule + elif log(base).is_zero: + return ConstantRule(1, symbol) + + return PiecewiseRule(integrand, symbol, [ + (rule, Ne(log(base), 0)), + (ConstantRule(1, symbol), True) + ]) + +def exp_rule(integral): + integrand, symbol = integral + if isinstance(integrand.args[0], Symbol): + return ExpRule(integrand, symbol, E, integrand.args[0]) + + +def orthogonal_poly_rule(integral): + orthogonal_poly_classes = { + jacobi: JacobiRule, + gegenbauer: GegenbauerRule, + chebyshevt: ChebyshevTRule, + chebyshevu: ChebyshevURule, + legendre: LegendreRule, + hermite: HermiteRule, + laguerre: LaguerreRule, + assoc_laguerre: AssocLaguerreRule + } + orthogonal_poly_var_index = { + jacobi: 3, + gegenbauer: 2, + assoc_laguerre: 2 + } + integrand, symbol = integral + for klass in orthogonal_poly_classes: + if isinstance(integrand, klass): + var_index = orthogonal_poly_var_index.get(klass, 1) + if (integrand.args[var_index] is symbol and not + any(v.has(symbol) for v in integrand.args[:var_index])): + return orthogonal_poly_classes[klass](integrand, symbol, *integrand.args[:var_index]) + + +_special_function_patterns: list[tuple[Type, Expr, Callable | None, tuple]] = [] +_wilds = [] +_symbol = Dummy('x') + + +def special_function_rule(integral): + integrand, symbol = integral + if not _special_function_patterns: + a = Wild('a', exclude=[_symbol], properties=[lambda x: not x.is_zero]) + b = Wild('b', exclude=[_symbol]) + c = Wild('c', exclude=[_symbol]) + d = Wild('d', exclude=[_symbol], properties=[lambda x: not x.is_zero]) + e = Wild('e', exclude=[_symbol], properties=[ + lambda x: not (x.is_nonnegative and x.is_integer)]) + _wilds.extend((a, b, c, d, e)) + # patterns consist of a SymPy class, a wildcard expr, an optional + # condition coded as a lambda (when Wild properties are not enough), + # followed by an applicable rule + linear_pattern = a*_symbol + b + quadratic_pattern = a*_symbol**2 + b*_symbol + c + _special_function_patterns.extend(( + (Mul, exp(linear_pattern, evaluate=False)/_symbol, None, EiRule), + (Mul, cos(linear_pattern, evaluate=False)/_symbol, None, CiRule), + (Mul, cosh(linear_pattern, evaluate=False)/_symbol, None, ChiRule), + (Mul, sin(linear_pattern, evaluate=False)/_symbol, None, SiRule), + (Mul, sinh(linear_pattern, evaluate=False)/_symbol, None, ShiRule), + (Pow, 1/log(linear_pattern, evaluate=False), None, LiRule), + (exp, exp(quadratic_pattern, evaluate=False), None, ErfRule), + (sin, sin(quadratic_pattern, evaluate=False), None, FresnelSRule), + (cos, cos(quadratic_pattern, evaluate=False), None, FresnelCRule), + (Mul, _symbol**e*exp(a*_symbol, evaluate=False), None, UpperGammaRule), + (Mul, polylog(b, a*_symbol, evaluate=False)/_symbol, None, PolylogRule), + (Pow, 1/sqrt(a - d*sin(_symbol, evaluate=False)**2), + lambda a, d: a != d, EllipticFRule), + (Pow, sqrt(a - d*sin(_symbol, evaluate=False)**2), + lambda a, d: a != d, EllipticERule), + )) + _integrand = integrand.subs(symbol, _symbol) + for type_, pattern, constraint, rule in _special_function_patterns: + if isinstance(_integrand, type_): + match = _integrand.match(pattern) + if match: + wild_vals = tuple(match.get(w) for w in _wilds + if match.get(w) is not None) + if constraint is None or constraint(*wild_vals): + return rule(integrand, symbol, *wild_vals) + + +def _add_degenerate_step(generic_cond, generic_step: Rule, degenerate_step: Rule | None) -> Rule: + if degenerate_step is None: + return generic_step + if isinstance(generic_step, PiecewiseRule): + subfunctions = [(substep, (cond & generic_cond).simplify()) + for substep, cond in generic_step.subfunctions] + else: + subfunctions = [(generic_step, generic_cond)] + if isinstance(degenerate_step, PiecewiseRule): + subfunctions += degenerate_step.subfunctions + else: + subfunctions.append((degenerate_step, S.true)) + return PiecewiseRule(generic_step.integrand, generic_step.variable, subfunctions) + + +def nested_pow_rule(integral: IntegralInfo): + # nested (c*(a+b*x)**d)**e + integrand, x = integral + + a_ = Wild('a', exclude=[x]) + b_ = Wild('b', exclude=[x, 0]) + pattern = a_+b_*x + generic_cond = S.true + + class NoMatch(Exception): + pass + + def _get_base_exp(expr: Expr) -> tuple[Expr, Expr]: + if not expr.has_free(x): + return S.One, S.Zero + if expr.is_Mul: + _, terms = expr.as_coeff_mul() + if not terms: + return S.One, S.Zero + results = [_get_base_exp(term) for term in terms] + bases = {b for b, _ in results} + bases.discard(S.One) + if len(bases) == 1: + return bases.pop(), Add(*(e for _, e in results)) + raise NoMatch + if expr.is_Pow: + b, e = expr.base, expr.exp # type: ignore + if e.has_free(x): + raise NoMatch + base_, sub_exp = _get_base_exp(b) + return base_, sub_exp * e + match = expr.match(pattern) + if match: + a, b = match[a_], match[b_] + base_ = x + a/b + nonlocal generic_cond + generic_cond = Ne(b, 0) + return base_, S.One + raise NoMatch + + try: + base, exp_ = _get_base_exp(integrand) + except NoMatch: + return + if generic_cond is S.true: + degenerate_step = None + else: + # equivalent with subs(b, 0) but no need to find b + degenerate_step = ConstantRule(integrand.subs(x, 0), x) + generic_step = NestedPowRule(integrand, x, base, exp_) + return _add_degenerate_step(generic_cond, generic_step, degenerate_step) + + +def inverse_trig_rule(integral: IntegralInfo, degenerate=True): + """ + Set degenerate=False on recursive call where coefficient of quadratic term + is assumed non-zero. + """ + integrand, symbol = integral + base, exp = integrand.as_base_exp() + a = Wild('a', exclude=[symbol]) + b = Wild('b', exclude=[symbol]) + c = Wild('c', exclude=[symbol, 0]) + match = base.match(a + b*symbol + c*symbol**2) + + if not match: + return + + def make_inverse_trig(RuleClass, a, sign_a, c, sign_c, h) -> Rule: + u_var = Dummy("u") + rewritten = 1/sqrt(sign_a*a + sign_c*c*(symbol-h)**2) # a>0, c>0 + quadratic_base = sqrt(c/a)*(symbol-h) + constant = 1/sqrt(c) + u_func = None + if quadratic_base is not symbol: + u_func = quadratic_base + quadratic_base = u_var + standard_form = 1/sqrt(sign_a + sign_c*quadratic_base**2) + substep = RuleClass(standard_form, quadratic_base) + if constant != 1: + substep = ConstantTimesRule(constant*standard_form, symbol, constant, standard_form, substep) + if u_func is not None: + substep = URule(rewritten, symbol, u_var, u_func, substep) + if h != 0: + substep = CompleteSquareRule(integrand, symbol, rewritten, substep) + return substep + + a, b, c = [match.get(i, S.Zero) for i in (a, b, c)] + generic_cond = Ne(c, 0) + if not degenerate or generic_cond is S.true: + degenerate_step = None + elif b.is_zero: + degenerate_step = ConstantRule(a ** exp, symbol) + else: + degenerate_step = sqrt_linear_rule(IntegralInfo((a + b * symbol) ** exp, symbol)) + + if simplify(2*exp + 1) == 0: + h, k = -b/(2*c), a - b**2/(4*c) # rewrite base to k + c*(symbol-h)**2 + non_square_cond = Ne(k, 0) + square_step = None + if non_square_cond is not S.true: + square_step = NestedPowRule(1/sqrt(c*(symbol-h)**2), symbol, symbol-h, S.NegativeOne) + if non_square_cond is S.false: + return square_step + generic_step = ReciprocalSqrtQuadraticRule(integrand, symbol, a, b, c) + step = _add_degenerate_step(non_square_cond, generic_step, square_step) + if k.is_real and c.is_real: + # list of ((rule, base_exp, a, sign_a, b, sign_b), condition) + rules = [] + for args, cond in ( # don't apply ArccoshRule to x**2-1 + ((ArcsinRule, k, 1, -c, -1, h), And(k > 0, c < 0)), # 1-x**2 + ((ArcsinhRule, k, 1, c, 1, h), And(k > 0, c > 0)), # 1+x**2 + ): + if cond is S.true: + return make_inverse_trig(*args) + if cond is not S.false: + rules.append((make_inverse_trig(*args), cond)) + if rules: + if not k.is_positive: # conditions are not thorough, need fall back rule + rules.append((generic_step, S.true)) + step = PiecewiseRule(integrand, symbol, rules) + else: + step = generic_step + return _add_degenerate_step(generic_cond, step, degenerate_step) + if exp == S.Half: + step = SqrtQuadraticRule(integrand, symbol, a, b, c) + return _add_degenerate_step(generic_cond, step, degenerate_step) + + +def add_rule(integral): + integrand, symbol = integral + results = [integral_steps(g, symbol) + for g in integrand.as_ordered_terms()] + return None if None in results else AddRule(integrand, symbol, results) + + +def mul_rule(integral: IntegralInfo): + integrand, symbol = integral + + # Constant times function case + coeff, f = integrand.as_independent(symbol) + if coeff != 1: + next_step = integral_steps(f, symbol) + if next_step is not None: + return ConstantTimesRule(integrand, symbol, coeff, f, next_step) + + +def _parts_rule(integrand, symbol) -> tuple[Expr, Expr, Expr, Expr, Rule] | None: + # LIATE rule: + # log, inverse trig, algebraic, trigonometric, exponential + def pull_out_algebraic(integrand): + integrand = integrand.cancel().together() + # iterating over Piecewise args would not work here + algebraic = ([] if isinstance(integrand, Piecewise) or not integrand.is_Mul + else [arg for arg in integrand.args if arg.is_algebraic_expr(symbol)]) + if algebraic: + u = Mul(*algebraic) + dv = (integrand / u).cancel() + return u, dv + + def pull_out_u(*functions) -> Callable[[Expr], tuple[Expr, Expr] | None]: + def pull_out_u_rl(integrand: Expr) -> tuple[Expr, Expr] | None: + if any(integrand.has(f) for f in functions): + args = [arg for arg in integrand.args + if any(isinstance(arg, cls) for cls in functions)] + if args: + u = Mul(*args) # type: ignore + dv = integrand / u + return u, dv + return None + + return pull_out_u_rl + + liate_rules = [pull_out_u(log), pull_out_u(*inverse_trig_functions), + pull_out_algebraic, pull_out_u(sin, cos), + pull_out_u(exp)] + + + dummy = Dummy("temporary") + # we can integrate log(x) and atan(x) by setting dv = 1 + if isinstance(integrand, (log, *inverse_trig_functions)): + integrand = dummy * integrand + + for index, rule in enumerate(liate_rules): + result = rule(integrand) + + if result: + u, dv = result + + # Don't pick u to be a constant if possible + if symbol not in u.free_symbols and not u.has(dummy): + return None + + u = u.subs(dummy, 1) + dv = dv.subs(dummy, 1) + + # Don't pick a non-polynomial algebraic to be differentiated + if rule == pull_out_algebraic and not u.is_polynomial(symbol): + return None + # Don't trade one logarithm for another + if isinstance(u, log): + rec_dv = 1/dv + if (rec_dv.is_polynomial(symbol) and + degree(rec_dv, symbol) == 1): + return None + + # Can integrate a polynomial times OrthogonalPolynomial + if rule == pull_out_algebraic: + if dv.is_Derivative or dv.has(TrigonometricFunction) or \ + isinstance(dv, OrthogonalPolynomial): + v_step = integral_steps(dv, symbol) + if v_step.contains_dont_know(): + return None + else: + du = u.diff(symbol) + v = v_step.eval() + return u, dv, v, du, v_step + + # make sure dv is amenable to integration + accept = False + if index < 2: # log and inverse trig are usually worth trying + accept = True + elif (rule == pull_out_algebraic and dv.args and + all(isinstance(a, (sin, cos, exp)) + for a in dv.args)): + accept = True + else: + for lrule in liate_rules[index + 1:]: + r = lrule(integrand) + if r and r[0].subs(dummy, 1).equals(dv): + accept = True + break + + if accept: + du = u.diff(symbol) + v_step = integral_steps(simplify(dv), symbol) + if not v_step.contains_dont_know(): + v = v_step.eval() + return u, dv, v, du, v_step + return None + + +def parts_rule(integral): + integrand, symbol = integral + constant, integrand = integrand.as_coeff_Mul() + + result = _parts_rule(integrand, symbol) + + steps = [] + if result: + u, dv, v, du, v_step = result + debug("u : {}, dv : {}, v : {}, du : {}, v_step: {}".format(u, dv, v, du, v_step)) + steps.append(result) + + if isinstance(v, Integral): + return + + # Set a limit on the number of times u can be used + if isinstance(u, (sin, cos, exp, sinh, cosh)): + cachekey = u.xreplace({symbol: _cache_dummy}) + if _parts_u_cache[cachekey] > 2: + return + _parts_u_cache[cachekey] += 1 + + # Try cyclic integration by parts a few times + for _ in range(4): + debug("Cyclic integration {} with v: {}, du: {}, integrand: {}".format(_, v, du, integrand)) + coefficient = ((v * du) / integrand).cancel() + if coefficient == 1: + break + if symbol not in coefficient.free_symbols: + rule = CyclicPartsRule(integrand, symbol, + [PartsRule(None, None, u, dv, v_step, None) + for (u, dv, v, du, v_step) in steps], + (-1) ** len(steps) * coefficient) + if (constant != 1) and rule: + rule = ConstantTimesRule(constant * integrand, symbol, constant, integrand, rule) + return rule + + # _parts_rule is sensitive to constants, factor it out + next_constant, next_integrand = (v * du).as_coeff_Mul() + result = _parts_rule(next_integrand, symbol) + + if result: + u, dv, v, du, v_step = result + u *= next_constant + du *= next_constant + steps.append((u, dv, v, du, v_step)) + else: + break + + def make_second_step(steps, integrand): + if steps: + u, dv, v, du, v_step = steps[0] + return PartsRule(integrand, symbol, u, dv, v_step, make_second_step(steps[1:], v * du)) + return integral_steps(integrand, symbol) + + if steps: + u, dv, v, du, v_step = steps[0] + rule = PartsRule(integrand, symbol, u, dv, v_step, make_second_step(steps[1:], v * du)) + if (constant != 1) and rule: + rule = ConstantTimesRule(constant * integrand, symbol, constant, integrand, rule) + return rule + + +def trig_rule(integral): + integrand, symbol = integral + if integrand == sin(symbol): + return SinRule(integrand, symbol) + if integrand == cos(symbol): + return CosRule(integrand, symbol) + if integrand == sec(symbol)**2: + return Sec2Rule(integrand, symbol) + if integrand == csc(symbol)**2: + return Csc2Rule(integrand, symbol) + + if isinstance(integrand, tan): + rewritten = sin(*integrand.args) / cos(*integrand.args) + elif isinstance(integrand, cot): + rewritten = cos(*integrand.args) / sin(*integrand.args) + elif isinstance(integrand, sec): + arg = integrand.args[0] + rewritten = ((sec(arg)**2 + tan(arg) * sec(arg)) / + (sec(arg) + tan(arg))) + elif isinstance(integrand, csc): + arg = integrand.args[0] + rewritten = ((csc(arg)**2 + cot(arg) * csc(arg)) / + (csc(arg) + cot(arg))) + else: + return + + return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol)) + +def trig_product_rule(integral: IntegralInfo): + integrand, symbol = integral + if integrand == sec(symbol) * tan(symbol): + return SecTanRule(integrand, symbol) + if integrand == csc(symbol) * cot(symbol): + return CscCotRule(integrand, symbol) + + +def quadratic_denom_rule(integral): + integrand, symbol = integral + a = Wild('a', exclude=[symbol]) + b = Wild('b', exclude=[symbol]) + c = Wild('c', exclude=[symbol]) + + match = integrand.match(a / (b * symbol ** 2 + c)) + + if match: + a, b, c = match[a], match[b], match[c] + general_rule = ArctanRule(integrand, symbol, a, b, c) + if b.is_extended_real and c.is_extended_real: + positive_cond = c/b > 0 + if positive_cond is S.true: + return general_rule + coeff = a/(2*sqrt(-c)*sqrt(b)) + constant = sqrt(-c/b) + r1 = 1/(symbol-constant) + r2 = 1/(symbol+constant) + log_steps = [ReciprocalRule(r1, symbol, symbol-constant), + ConstantTimesRule(-r2, symbol, -1, r2, ReciprocalRule(r2, symbol, symbol+constant))] + rewritten = sub = r1 - r2 + negative_step = AddRule(sub, symbol, log_steps) + if coeff != 1: + rewritten = Mul(coeff, sub, evaluate=False) + negative_step = ConstantTimesRule(rewritten, symbol, coeff, sub, negative_step) + negative_step = RewriteRule(integrand, symbol, rewritten, negative_step) + if positive_cond is S.false: + return negative_step + return PiecewiseRule(integrand, symbol, [(general_rule, positive_cond), (negative_step, S.true)]) + + power = PowerRule(integrand, symbol, symbol, -2) + if b != 1: + power = ConstantTimesRule(integrand, symbol, 1/b, symbol**-2, power) + + return PiecewiseRule(integrand, symbol, [(general_rule, Ne(c, 0)), (power, True)]) + + d = Wild('d', exclude=[symbol]) + match2 = integrand.match(a / (b * symbol ** 2 + c * symbol + d)) + if match2: + b, c = match2[b], match2[c] + if b.is_zero: + return + u = Dummy('u') + u_func = symbol + c/(2*b) + integrand2 = integrand.subs(symbol, u - c / (2*b)) + next_step = integral_steps(integrand2, u) + if next_step: + return URule(integrand2, symbol, u, u_func, next_step) + else: + return + e = Wild('e', exclude=[symbol]) + match3 = integrand.match((a* symbol + b) / (c * symbol ** 2 + d * symbol + e)) + if match3: + a, b, c, d, e = match3[a], match3[b], match3[c], match3[d], match3[e] + if c.is_zero: + return + denominator = c * symbol**2 + d * symbol + e + const = a/(2*c) + numer1 = (2*c*symbol+d) + numer2 = - const*d + b + u = Dummy('u') + step1 = URule(integrand, symbol, + u, denominator, integral_steps(u**(-1), u)) + if const != 1: + step1 = ConstantTimesRule(const*numer1/denominator, symbol, + const, numer1/denominator, step1) + if numer2.is_zero: + return step1 + step2 = integral_steps(numer2/denominator, symbol) + substeps = AddRule(integrand, symbol, [step1, step2]) + rewriten = const*numer1/denominator+numer2/denominator + return RewriteRule(integrand, symbol, rewriten, substeps) + + return + + +def sqrt_linear_rule(integral: IntegralInfo): + """ + Substitute common (a+b*x)**(1/n) + """ + integrand, x = integral + a = Wild('a', exclude=[x]) + b = Wild('b', exclude=[x, 0]) + a0 = b0 = 0 + bases, qs, bs = [], [], [] + for pow_ in integrand.find(Pow): # collect all (a+b*x)**(p/q) + base, exp_ = pow_.base, pow_.exp + if exp_.is_Integer or x not in base.free_symbols: # skip 1/x and sqrt(2) + continue + if not exp_.is_Rational: # exclude x**pi + return + match = base.match(a+b*x) + if not match: # skip non-linear + continue # for sqrt(x+sqrt(x)), although base is non-linear, we can still substitute sqrt(x) + a1, b1 = match[a], match[b] + if a0*b1 != a1*b0 or not (b0/b1).is_nonnegative: # cannot transform sqrt(x) to sqrt(x+1) or sqrt(-x) + return + if b0 == 0 or (b0/b1 > 1) is S.true: # choose the latter of sqrt(2*x) and sqrt(x) as representative + a0, b0 = a1, b1 + bases.append(base) + bs.append(b1) + qs.append(exp_.q) + if b0 == 0: # no such pattern found + return + q0: Integer = lcm_list(qs) + u_x = (a0 + b0*x)**(1/q0) + u = Dummy("u") + substituted = integrand.subs({base**(S.One/q): (b/b0)**(S.One/q)*u**(q0/q) + for base, b, q in zip(bases, bs, qs)}).subs(x, (u**q0-a0)/b0) + substep = integral_steps(substituted*u**(q0-1)*q0/b0, u) + if not substep.contains_dont_know(): + step: Rule = URule(integrand, x, u, u_x, substep) + generic_cond = Ne(b0, 0) + if generic_cond is not S.true: # possible degenerate case + simplified = integrand.subs(dict.fromkeys(bs, 0)) + degenerate_step = integral_steps(simplified, x) + step = PiecewiseRule(integrand, x, [(step, generic_cond), (degenerate_step, S.true)]) + return step + + +def sqrt_quadratic_rule(integral: IntegralInfo, degenerate=True): + integrand, x = integral + a = Wild('a', exclude=[x]) + b = Wild('b', exclude=[x]) + c = Wild('c', exclude=[x, 0]) + f = Wild('f') + n = Wild('n', properties=[lambda n: n.is_Integer and n.is_odd]) + match = integrand.match(f*sqrt(a+b*x+c*x**2)**n) + if not match: + return + a, b, c, f, n = match[a], match[b], match[c], match[f], match[n] + f_poly = f.as_poly(x) + if f_poly is None: + return + + generic_cond = Ne(c, 0) + if not degenerate or generic_cond is S.true: + degenerate_step = None + elif b.is_zero: + degenerate_step = integral_steps(f*sqrt(a)**n, x) + else: + degenerate_step = sqrt_linear_rule(IntegralInfo(f*sqrt(a+b*x)**n, x)) + + def sqrt_quadratic_denom_rule(numer_poly: Poly, integrand: Expr): + denom = sqrt(a+b*x+c*x**2) + deg = numer_poly.degree() + if deg <= 1: + # integrand == (d+e*x)/sqrt(a+b*x+c*x**2) + e, d = numer_poly.all_coeffs() if deg == 1 else (S.Zero, numer_poly.as_expr()) + # rewrite numerator to A*(2*c*x+b) + B + A = e/(2*c) + B = d-A*b + pre_substitute = (2*c*x+b)/denom + constant_step: Rule | None = None + linear_step: Rule | None = None + if A != 0: + u = Dummy("u") + pow_rule = PowerRule(1/sqrt(u), u, u, -S.Half) + linear_step = URule(pre_substitute, x, u, a+b*x+c*x**2, pow_rule) + if A != 1: + linear_step = ConstantTimesRule(A*pre_substitute, x, A, pre_substitute, linear_step) + if B != 0: + constant_step = inverse_trig_rule(IntegralInfo(1/denom, x), degenerate=False) + if B != 1: + constant_step = ConstantTimesRule(B/denom, x, B, 1/denom, constant_step) # type: ignore + if linear_step and constant_step: + add = Add(A*pre_substitute, B/denom, evaluate=False) + step: Rule | None = RewriteRule(integrand, x, add, AddRule(add, x, [linear_step, constant_step])) + else: + step = linear_step or constant_step + else: + coeffs = numer_poly.all_coeffs() + step = SqrtQuadraticDenomRule(integrand, x, a, b, c, coeffs) + return step + + if n > 0: # rewrite poly * sqrt(s)**(2*k-1) to poly*s**k / sqrt(s) + numer_poly = f_poly * (a+b*x+c*x**2)**((n+1)/2) + rewritten = numer_poly.as_expr()/sqrt(a+b*x+c*x**2) + substep = sqrt_quadratic_denom_rule(numer_poly, rewritten) + generic_step = RewriteRule(integrand, x, rewritten, substep) + elif n == -1: + generic_step = sqrt_quadratic_denom_rule(f_poly, integrand) + else: + return # todo: handle n < -1 case + return _add_degenerate_step(generic_cond, generic_step, degenerate_step) + + +def hyperbolic_rule(integral: tuple[Expr, Symbol]): + integrand, symbol = integral + if isinstance(integrand, HyperbolicFunction) and integrand.args[0] == symbol: + if integrand.func == sinh: + return SinhRule(integrand, symbol) + if integrand.func == cosh: + return CoshRule(integrand, symbol) + u = Dummy('u') + if integrand.func == tanh: + rewritten = sinh(symbol)/cosh(symbol) + return RewriteRule(integrand, symbol, rewritten, + URule(rewritten, symbol, u, cosh(symbol), ReciprocalRule(1/u, u, u))) + if integrand.func == coth: + rewritten = cosh(symbol)/sinh(symbol) + return RewriteRule(integrand, symbol, rewritten, + URule(rewritten, symbol, u, sinh(symbol), ReciprocalRule(1/u, u, u))) + else: + rewritten = integrand.rewrite(tanh) + if integrand.func == sech: + return RewriteRule(integrand, symbol, rewritten, + URule(rewritten, symbol, u, tanh(symbol/2), + ArctanRule(2/(u**2 + 1), u, S(2), S.One, S.One))) + if integrand.func == csch: + return RewriteRule(integrand, symbol, rewritten, + URule(rewritten, symbol, u, tanh(symbol/2), + ReciprocalRule(1/u, u, u))) + +@cacheit +def make_wilds(symbol): + a = Wild('a', exclude=[symbol]) + b = Wild('b', exclude=[symbol]) + m = Wild('m', exclude=[symbol], properties=[lambda n: isinstance(n, Integer)]) + n = Wild('n', exclude=[symbol], properties=[lambda n: isinstance(n, Integer)]) + + return a, b, m, n + +@cacheit +def sincos_pattern(symbol): + a, b, m, n = make_wilds(symbol) + pattern = sin(a*symbol)**m * cos(b*symbol)**n + + return pattern, a, b, m, n + +@cacheit +def tansec_pattern(symbol): + a, b, m, n = make_wilds(symbol) + pattern = tan(a*symbol)**m * sec(b*symbol)**n + + return pattern, a, b, m, n + +@cacheit +def cotcsc_pattern(symbol): + a, b, m, n = make_wilds(symbol) + pattern = cot(a*symbol)**m * csc(b*symbol)**n + + return pattern, a, b, m, n + +@cacheit +def heaviside_pattern(symbol): + m = Wild('m', exclude=[symbol]) + b = Wild('b', exclude=[symbol]) + g = Wild('g') + pattern = Heaviside(m*symbol + b) * g + + return pattern, m, b, g + +def uncurry(func): + def uncurry_rl(args): + return func(*args) + return uncurry_rl + +def trig_rewriter(rewrite): + def trig_rewriter_rl(args): + a, b, m, n, integrand, symbol = args + rewritten = rewrite(a, b, m, n, integrand, symbol) + if rewritten != integrand: + return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol)) + return trig_rewriter_rl + +sincos_botheven_condition = uncurry( + lambda a, b, m, n, i, s: m.is_even and n.is_even and + m.is_nonnegative and n.is_nonnegative) + +sincos_botheven = trig_rewriter( + lambda a, b, m, n, i, symbol: ( (((1 - cos(2*a*symbol)) / 2) ** (m / 2)) * + (((1 + cos(2*b*symbol)) / 2) ** (n / 2)) )) + +sincos_sinodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd and m >= 3) + +sincos_sinodd = trig_rewriter( + lambda a, b, m, n, i, symbol: ( (1 - cos(a*symbol)**2)**((m - 1) / 2) * + sin(a*symbol) * + cos(b*symbol) ** n)) + +sincos_cosodd_condition = uncurry(lambda a, b, m, n, i, s: n.is_odd and n >= 3) + +sincos_cosodd = trig_rewriter( + lambda a, b, m, n, i, symbol: ( (1 - sin(b*symbol)**2)**((n - 1) / 2) * + cos(b*symbol) * + sin(a*symbol) ** m)) + +tansec_seceven_condition = uncurry(lambda a, b, m, n, i, s: n.is_even and n >= 4) +tansec_seceven = trig_rewriter( + lambda a, b, m, n, i, symbol: ( (1 + tan(b*symbol)**2) ** (n/2 - 1) * + sec(b*symbol)**2 * + tan(a*symbol) ** m )) + +tansec_tanodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd) +tansec_tanodd = trig_rewriter( + lambda a, b, m, n, i, symbol: ( (sec(a*symbol)**2 - 1) ** ((m - 1) / 2) * + tan(a*symbol) * + sec(b*symbol) ** n )) + +tan_tansquared_condition = uncurry(lambda a, b, m, n, i, s: m == 2 and n == 0) +tan_tansquared = trig_rewriter( + lambda a, b, m, n, i, symbol: ( sec(a*symbol)**2 - 1)) + +cotcsc_csceven_condition = uncurry(lambda a, b, m, n, i, s: n.is_even and n >= 4) +cotcsc_csceven = trig_rewriter( + lambda a, b, m, n, i, symbol: ( (1 + cot(b*symbol)**2) ** (n/2 - 1) * + csc(b*symbol)**2 * + cot(a*symbol) ** m )) + +cotcsc_cotodd_condition = uncurry(lambda a, b, m, n, i, s: m.is_odd) +cotcsc_cotodd = trig_rewriter( + lambda a, b, m, n, i, symbol: ( (csc(a*symbol)**2 - 1) ** ((m - 1) / 2) * + cot(a*symbol) * + csc(b*symbol) ** n )) + +def trig_sincos_rule(integral): + integrand, symbol = integral + + if any(integrand.has(f) for f in (sin, cos)): + pattern, a, b, m, n = sincos_pattern(symbol) + match = integrand.match(pattern) + if not match: + return + + return multiplexer({ + sincos_botheven_condition: sincos_botheven, + sincos_sinodd_condition: sincos_sinodd, + sincos_cosodd_condition: sincos_cosodd + })(tuple( + [match.get(i, S.Zero) for i in (a, b, m, n)] + + [integrand, symbol])) + +def trig_tansec_rule(integral): + integrand, symbol = integral + + integrand = integrand.subs({ + 1 / cos(symbol): sec(symbol) + }) + + if any(integrand.has(f) for f in (tan, sec)): + pattern, a, b, m, n = tansec_pattern(symbol) + match = integrand.match(pattern) + if not match: + return + + return multiplexer({ + tansec_tanodd_condition: tansec_tanodd, + tansec_seceven_condition: tansec_seceven, + tan_tansquared_condition: tan_tansquared + })(tuple( + [match.get(i, S.Zero) for i in (a, b, m, n)] + + [integrand, symbol])) + +def trig_cotcsc_rule(integral): + integrand, symbol = integral + integrand = integrand.subs({ + 1 / sin(symbol): csc(symbol), + 1 / tan(symbol): cot(symbol), + cos(symbol) / tan(symbol): cot(symbol) + }) + + if any(integrand.has(f) for f in (cot, csc)): + pattern, a, b, m, n = cotcsc_pattern(symbol) + match = integrand.match(pattern) + if not match: + return + + return multiplexer({ + cotcsc_cotodd_condition: cotcsc_cotodd, + cotcsc_csceven_condition: cotcsc_csceven + })(tuple( + [match.get(i, S.Zero) for i in (a, b, m, n)] + + [integrand, symbol])) + +def trig_sindouble_rule(integral): + integrand, symbol = integral + a = Wild('a', exclude=[sin(2*symbol)]) + match = integrand.match(sin(2*symbol)*a) + if match: + sin_double = 2*sin(symbol)*cos(symbol)/sin(2*symbol) + return integral_steps(integrand * sin_double, symbol) + +def trig_powers_products_rule(integral): + return do_one(null_safe(trig_sincos_rule), + null_safe(trig_tansec_rule), + null_safe(trig_cotcsc_rule), + null_safe(trig_sindouble_rule))(integral) + +def trig_substitution_rule(integral): + integrand, symbol = integral + A = Wild('a', exclude=[0, symbol]) + B = Wild('b', exclude=[0, symbol]) + theta = Dummy("theta") + target_pattern = A + B*symbol**2 + + matches = integrand.find(target_pattern) + for expr in matches: + match = expr.match(target_pattern) + a = match.get(A, S.Zero) + b = match.get(B, S.Zero) + + a_positive = ((a.is_number and a > 0) or a.is_positive) + b_positive = ((b.is_number and b > 0) or b.is_positive) + a_negative = ((a.is_number and a < 0) or a.is_negative) + b_negative = ((b.is_number and b < 0) or b.is_negative) + x_func = None + if a_positive and b_positive: + # a**2 + b*x**2. Assume sec(theta) > 0, -pi/2 < theta < pi/2 + x_func = (sqrt(a)/sqrt(b)) * tan(theta) + # Do not restrict the domain: tan(theta) takes on any real + # value on the interval -pi/2 < theta < pi/2 so x takes on + # any value + restriction = True + elif a_positive and b_negative: + # a**2 - b*x**2. Assume cos(theta) > 0, -pi/2 < theta < pi/2 + constant = sqrt(a)/sqrt(-b) + x_func = constant * sin(theta) + restriction = And(symbol > -constant, symbol < constant) + elif a_negative and b_positive: + # b*x**2 - a**2. Assume sin(theta) > 0, 0 < theta < pi + constant = sqrt(-a)/sqrt(b) + x_func = constant * sec(theta) + restriction = And(symbol > -constant, symbol < constant) + if x_func: + # Manually simplify sqrt(trig(theta)**2) to trig(theta) + # Valid due to assumed domain restriction + substitutions = {} + for f in [sin, cos, tan, + sec, csc, cot]: + substitutions[sqrt(f(theta)**2)] = f(theta) + substitutions[sqrt(f(theta)**(-2))] = 1/f(theta) + + replaced = integrand.subs(symbol, x_func).trigsimp() + replaced = manual_subs(replaced, substitutions) + if not replaced.has(symbol): + replaced *= manual_diff(x_func, theta) + replaced = replaced.trigsimp() + secants = replaced.find(1/cos(theta)) + if secants: + replaced = replaced.xreplace({ + 1/cos(theta): sec(theta) + }) + + substep = integral_steps(replaced, theta) + if not substep.contains_dont_know(): + return TrigSubstitutionRule(integrand, symbol, + theta, x_func, replaced, substep, restriction) + +def heaviside_rule(integral): + integrand, symbol = integral + pattern, m, b, g = heaviside_pattern(symbol) + match = integrand.match(pattern) + if match and 0 != match[g]: + # f = Heaviside(m*x + b)*g + substep = integral_steps(match[g], symbol) + m, b = match[m], match[b] + return HeavisideRule(integrand, symbol, m*symbol + b, -b/m, substep) + + +def dirac_delta_rule(integral: IntegralInfo): + integrand, x = integral + if len(integrand.args) == 1: + n = S.Zero + else: + n = integrand.args[1] # type: ignore + if not n.is_Integer or n < 0: + return + a, b = Wild('a', exclude=[x]), Wild('b', exclude=[x, 0]) + match = integrand.args[0].match(a+b*x) + if not match: + return + a, b = match[a], match[b] + generic_cond = Ne(b, 0) + if generic_cond is S.true: + degenerate_step = None + else: + degenerate_step = ConstantRule(DiracDelta(a, n), x) + generic_step = DiracDeltaRule(integrand, x, n, a, b) + return _add_degenerate_step(generic_cond, generic_step, degenerate_step) + + +def substitution_rule(integral): + integrand, symbol = integral + + u_var = Dummy("u") + substitutions = find_substitutions(integrand, symbol, u_var) + count = 0 + if substitutions: + debug("List of Substitution Rules") + ways = [] + for u_func, c, substituted in substitutions: + subrule = integral_steps(substituted, u_var) + count = count + 1 + debug("Rule {}: {}".format(count, subrule)) + + if subrule.contains_dont_know(): + continue + + if simplify(c - 1) != 0: + _, denom = c.as_numer_denom() + if subrule: + subrule = ConstantTimesRule(c * substituted, u_var, c, substituted, subrule) + + if denom.free_symbols: + piecewise = [] + could_be_zero = [] + + if isinstance(denom, Mul): + could_be_zero = denom.args + else: + could_be_zero.append(denom) + + for expr in could_be_zero: + if not fuzzy_not(expr.is_zero): + substep = integral_steps(manual_subs(integrand, expr, 0), symbol) + + if substep: + piecewise.append(( + substep, + Eq(expr, 0) + )) + piecewise.append((subrule, True)) + subrule = PiecewiseRule(substituted, symbol, piecewise) + + ways.append(URule(integrand, symbol, u_var, u_func, subrule)) + + if len(ways) > 1: + return AlternativeRule(integrand, symbol, ways) + elif ways: + return ways[0] + + +partial_fractions_rule = rewriter( + lambda integrand, symbol: integrand.is_rational_function(), + lambda integrand, symbol: integrand.apart(symbol)) + +cancel_rule = rewriter( + # lambda integrand, symbol: integrand.is_algebraic_expr(), + # lambda integrand, symbol: isinstance(integrand, Mul), + lambda integrand, symbol: True, + lambda integrand, symbol: integrand.cancel()) + +distribute_expand_rule = rewriter( + lambda integrand, symbol: ( + isinstance(integrand, (Pow, Mul)) or all(arg.is_Pow or arg.is_polynomial(symbol) for arg in integrand.args)), + lambda integrand, symbol: integrand.expand()) + +trig_expand_rule = rewriter( + # If there are trig functions with different arguments, expand them + lambda integrand, symbol: ( + len({a.args[0] for a in integrand.atoms(TrigonometricFunction)}) > 1), + lambda integrand, symbol: integrand.expand(trig=True)) + +def derivative_rule(integral): + integrand = integral[0] + diff_variables = integrand.variables + undifferentiated_function = integrand.expr + integrand_variables = undifferentiated_function.free_symbols + + if integral.symbol in integrand_variables: + if integral.symbol in diff_variables: + return DerivativeRule(*integral) + else: + return DontKnowRule(integrand, integral.symbol) + else: + return ConstantRule(*integral) + +def rewrites_rule(integral): + integrand, symbol = integral + + if integrand.match(1/cos(symbol)): + rewritten = integrand.subs(1/cos(symbol), sec(symbol)) + return RewriteRule(integrand, symbol, rewritten, integral_steps(rewritten, symbol)) + +def fallback_rule(integral): + return DontKnowRule(*integral) + +# Cache is used to break cyclic integrals. +# Need to use the same dummy variable in cached expressions for them to match. +# Also record "u" of integration by parts, to avoid infinite repetition. +_integral_cache: dict[Expr, Expr | None] = {} +_parts_u_cache: dict[Expr, int] = defaultdict(int) +_cache_dummy = Dummy("z") + +def integral_steps(integrand, symbol, **options): + """Returns the steps needed to compute an integral. + + Explanation + =========== + + This function attempts to mirror what a student would do by hand as + closely as possible. + + SymPy Gamma uses this to provide a step-by-step explanation of an + integral. The code it uses to format the results of this function can be + found at + https://github.com/sympy/sympy_gamma/blob/master/app/logic/intsteps.py. + + Examples + ======== + + >>> from sympy import exp, sin + >>> from sympy.integrals.manualintegrate import integral_steps + >>> from sympy.abc import x + >>> print(repr(integral_steps(exp(x) / (1 + exp(2 * x)), x))) \ + # doctest: +NORMALIZE_WHITESPACE + URule(integrand=exp(x)/(exp(2*x) + 1), variable=x, u_var=_u, u_func=exp(x), + substep=ArctanRule(integrand=1/(_u**2 + 1), variable=_u, a=1, b=1, c=1)) + >>> print(repr(integral_steps(sin(x), x))) \ + # doctest: +NORMALIZE_WHITESPACE + SinRule(integrand=sin(x), variable=x) + >>> print(repr(integral_steps((x**2 + 3)**2, x))) \ + # doctest: +NORMALIZE_WHITESPACE + RewriteRule(integrand=(x**2 + 3)**2, variable=x, rewritten=x**4 + 6*x**2 + 9, + substep=AddRule(integrand=x**4 + 6*x**2 + 9, variable=x, + substeps=[PowerRule(integrand=x**4, variable=x, base=x, exp=4), + ConstantTimesRule(integrand=6*x**2, variable=x, constant=6, other=x**2, + substep=PowerRule(integrand=x**2, variable=x, base=x, exp=2)), + ConstantRule(integrand=9, variable=x)])) + + Returns + ======= + + rule : Rule + The first step; most rules have substeps that must also be + considered. These substeps can be evaluated using ``manualintegrate`` + to obtain a result. + + """ + cachekey = integrand.xreplace({symbol: _cache_dummy}) + if cachekey in _integral_cache: + if _integral_cache[cachekey] is None: + # Stop this attempt, because it leads around in a loop + return DontKnowRule(integrand, symbol) + else: + # TODO: This is for future development, as currently + # _integral_cache gets no values other than None + return (_integral_cache[cachekey].xreplace(_cache_dummy, symbol), + symbol) + else: + _integral_cache[cachekey] = None + + integral = IntegralInfo(integrand, symbol) + + def key(integral): + integrand = integral.integrand + + if symbol not in integrand.free_symbols: + return Number + for cls in (Symbol, TrigonometricFunction, OrthogonalPolynomial): + if isinstance(integrand, cls): + return cls + return type(integrand) + + def integral_is_subclass(*klasses): + def _integral_is_subclass(integral): + k = key(integral) + return k and issubclass(k, klasses) + return _integral_is_subclass + + result = do_one( + null_safe(special_function_rule), + null_safe(switch(key, { + Pow: do_one(null_safe(power_rule), null_safe(inverse_trig_rule), + null_safe(sqrt_linear_rule), + null_safe(quadratic_denom_rule)), + Symbol: power_rule, + exp: exp_rule, + Add: add_rule, + Mul: do_one(null_safe(mul_rule), null_safe(trig_product_rule), + null_safe(heaviside_rule), null_safe(quadratic_denom_rule), + null_safe(sqrt_linear_rule), + null_safe(sqrt_quadratic_rule)), + Derivative: derivative_rule, + TrigonometricFunction: trig_rule, + Heaviside: heaviside_rule, + DiracDelta: dirac_delta_rule, + OrthogonalPolynomial: orthogonal_poly_rule, + Number: constant_rule + })), + do_one( + null_safe(trig_rule), + null_safe(hyperbolic_rule), + null_safe(alternatives( + rewrites_rule, + substitution_rule, + condition( + integral_is_subclass(Mul, Pow), + partial_fractions_rule), + condition( + integral_is_subclass(Mul, Pow), + cancel_rule), + condition( + integral_is_subclass(Mul, log, + *inverse_trig_functions), + parts_rule), + condition( + integral_is_subclass(Mul, Pow), + distribute_expand_rule), + trig_powers_products_rule, + trig_expand_rule + )), + null_safe(condition(integral_is_subclass(Mul, Pow), nested_pow_rule)), + null_safe(trig_substitution_rule) + ), + fallback_rule)(integral) + del _integral_cache[cachekey] + return result + + +def manualintegrate(f, var): + """manualintegrate(f, var) + + Explanation + =========== + + Compute indefinite integral of a single variable using an algorithm that + resembles what a student would do by hand. + + Unlike :func:`~.integrate`, var can only be a single symbol. + + Examples + ======== + + >>> from sympy import sin, cos, tan, exp, log, integrate + >>> from sympy.integrals.manualintegrate import manualintegrate + >>> from sympy.abc import x + >>> manualintegrate(1 / x, x) + log(x) + >>> integrate(1/x) + log(x) + >>> manualintegrate(log(x), x) + x*log(x) - x + >>> integrate(log(x)) + x*log(x) - x + >>> manualintegrate(exp(x) / (1 + exp(2 * x)), x) + atan(exp(x)) + >>> integrate(exp(x) / (1 + exp(2 * x))) + RootSum(4*_z**2 + 1, Lambda(_i, _i*log(2*_i + exp(x)))) + >>> manualintegrate(cos(x)**4 * sin(x), x) + -cos(x)**5/5 + >>> integrate(cos(x)**4 * sin(x), x) + -cos(x)**5/5 + >>> manualintegrate(cos(x)**4 * sin(x)**3, x) + cos(x)**7/7 - cos(x)**5/5 + >>> integrate(cos(x)**4 * sin(x)**3, x) + cos(x)**7/7 - cos(x)**5/5 + >>> manualintegrate(tan(x), x) + -log(cos(x)) + >>> integrate(tan(x), x) + -log(cos(x)) + + See Also + ======== + + sympy.integrals.integrals.integrate + sympy.integrals.integrals.Integral.doit + sympy.integrals.integrals.Integral + """ + result = integral_steps(f, var).eval() + # Clear the cache of u-parts + _parts_u_cache.clear() + # If we got Piecewise with two parts, put generic first + if isinstance(result, Piecewise) and len(result.args) == 2: + cond = result.args[0][1] + if isinstance(cond, Eq) and result.args[1][1] == True: + result = result.func( + (result.args[1][0], Ne(*cond.args)), + (result.args[0][0], True)) + return result diff --git a/.venv/lib/python3.13/site-packages/sympy/integrals/meijerint.py b/.venv/lib/python3.13/site-packages/sympy/integrals/meijerint.py new file mode 100644 index 0000000000000000000000000000000000000000..89d8401c7eee0147df2e824dc1dc50b59ca7f0be --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/integrals/meijerint.py @@ -0,0 +1,2191 @@ +""" +Integrate functions by rewriting them as Meijer G-functions. + +There are three user-visible functions that can be used by other parts of the +sympy library to solve various integration problems: + +- meijerint_indefinite +- meijerint_definite +- meijerint_inversion + +They can be used to compute, respectively, indefinite integrals, definite +integrals over intervals of the real line, and inverse laplace-type integrals +(from c-I*oo to c+I*oo). See the respective docstrings for details. + +The main references for this are: + +[L] Luke, Y. L. (1969), The Special Functions and Their Approximations, + Volume 1 + +[R] Kelly B. Roach. Meijer G Function Representations. + In: Proceedings of the 1997 International Symposium on Symbolic and + Algebraic Computation, pages 205-211, New York, 1997. ACM. + +[P] A. P. Prudnikov, Yu. A. Brychkov and O. I. Marichev (1990). + Integrals and Series: More Special Functions, Vol. 3,. + Gordon and Breach Science Publisher +""" + +from __future__ import annotations +import itertools + +from sympy import SYMPY_DEBUG +from sympy.core import S, Expr +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.cache import cacheit +from sympy.core.containers import Tuple +from sympy.core.exprtools import factor_terms +from sympy.core.function import (expand, expand_mul, expand_power_base, + expand_trig, Function) +from sympy.core.mul import Mul +from sympy.core.intfunc import ilcm +from sympy.core.numbers import Rational, pi +from sympy.core.relational import Eq, Ne, _canonical_coeff +from sympy.core.sorting import default_sort_key, ordered +from sympy.core.symbol import Dummy, symbols, Wild, Symbol +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.complexes import (re, im, arg, Abs, sign, + unpolarify, polarify, polar_lift, principal_branch, unbranched_argument, + periodic_argument) +from sympy.functions.elementary.exponential import exp, exp_polar, log +from sympy.functions.elementary.integers import ceiling +from sympy.functions.elementary.hyperbolic import (cosh, sinh, + _rewrite_hyperbolics_as_exp, HyperbolicFunction) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise, piecewise_fold +from sympy.functions.elementary.trigonometric import (cos, sin, sinc, + TrigonometricFunction) +from sympy.functions.special.bessel import besselj, bessely, besseli, besselk +from sympy.functions.special.delta_functions import DiracDelta, Heaviside +from sympy.functions.special.elliptic_integrals import elliptic_k, elliptic_e +from sympy.functions.special.error_functions import (erf, erfc, erfi, Ei, + expint, Si, Ci, Shi, Chi, fresnels, fresnelc) +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.hyper import hyper, meijerg +from sympy.functions.special.singularity_functions import SingularityFunction +from .integrals import Integral +from sympy.logic.boolalg import And, Or, BooleanAtom, Not, BooleanFunction +from sympy.polys import cancel, factor +from sympy.utilities.iterables import multiset_partitions +from sympy.utilities.misc import debug as _debug +from sympy.utilities.misc import debugf as _debugf + +# keep this at top for easy reference +z = Dummy('z') + + +def _has(res, *f): + # return True if res has f; in the case of Piecewise + # only return True if *all* pieces have f + res = piecewise_fold(res) + if getattr(res, 'is_Piecewise', False): + return all(_has(i, *f) for i in res.args) + return res.has(*f) + + +def _create_lookup_table(table): + """ Add formulae for the function -> meijerg lookup table. """ + def wild(n): + return Wild(n, exclude=[z]) + p, q, a, b, c = list(map(wild, 'pqabc')) + n = Wild('n', properties=[lambda x: x.is_Integer and x > 0]) + t = p*z**q + + def add(formula, an, ap, bm, bq, arg=t, fac=S.One, cond=True, hint=True): + table.setdefault(_mytype(formula, z), []).append((formula, + [(fac, meijerg(an, ap, bm, bq, arg))], cond, hint)) + + def addi(formula, inst, cond, hint=True): + table.setdefault( + _mytype(formula, z), []).append((formula, inst, cond, hint)) + + def constant(a): + return [(a, meijerg([1], [], [], [0], z)), + (a, meijerg([], [1], [0], [], z))] + table[()] = [(a, constant(a), True, True)] + + # [P], Section 8. + class IsNonPositiveInteger(Function): + + @classmethod + def eval(cls, arg): + arg = unpolarify(arg) + if arg.is_Integer is True: + return arg <= 0 + + # Section 8.4.2 + # TODO this needs more polar_lift (c/f entry for exp) + add(Heaviside(t - b)*(t - b)**(a - 1), [a], [], [], [0], t/b, + gamma(a)*b**(a - 1), And(b > 0)) + add(Heaviside(b - t)*(b - t)**(a - 1), [], [a], [0], [], t/b, + gamma(a)*b**(a - 1), And(b > 0)) + add(Heaviside(z - (b/p)**(1/q))*(t - b)**(a - 1), [a], [], [], [0], t/b, + gamma(a)*b**(a - 1), And(b > 0)) + add(Heaviside((b/p)**(1/q) - z)*(b - t)**(a - 1), [], [a], [0], [], t/b, + gamma(a)*b**(a - 1), And(b > 0)) + add((b + t)**(-a), [1 - a], [], [0], [], t/b, b**(-a)/gamma(a), + hint=Not(IsNonPositiveInteger(a))) + add(Abs(b - t)**(-a), [1 - a], [(1 - a)/2], [0], [(1 - a)/2], t/b, + 2*sin(pi*a/2)*gamma(1 - a)*Abs(b)**(-a), re(a) < 1) + add((t**a - b**a)/(t - b), [0, a], [], [0, a], [], t/b, + b**(a - 1)*sin(a*pi)/pi) + + # 12 + def A1(r, sign, nu): + return pi**Rational(-1, 2)*(-sign*nu/2)**(1 - 2*r) + + def tmpadd(r, sgn): + # XXX the a**2 is bad for matching + add((sqrt(a**2 + t) + sgn*a)**b/(a**2 + t)**r, + [(1 + b)/2, 1 - 2*r + b/2], [], + [(b - sgn*b)/2], [(b + sgn*b)/2], t/a**2, + a**(b - 2*r)*A1(r, sgn, b)) + tmpadd(0, 1) + tmpadd(0, -1) + tmpadd(S.Half, 1) + tmpadd(S.Half, -1) + + # 13 + def tmpadd(r, sgn): + add((sqrt(a + p*z**q) + sgn*sqrt(p)*z**(q/2))**b/(a + p*z**q)**r, + [1 - r + sgn*b/2], [1 - r - sgn*b/2], [0, S.Half], [], + p*z**q/a, a**(b/2 - r)*A1(r, sgn, b)) + tmpadd(0, 1) + tmpadd(0, -1) + tmpadd(S.Half, 1) + tmpadd(S.Half, -1) + # (those after look obscure) + + # Section 8.4.3 + add(exp(polar_lift(-1)*t), [], [], [0], []) + + # TODO can do sin^n, sinh^n by expansion ... where? + # 8.4.4 (hyperbolic functions) + add(sinh(t), [], [1], [S.Half], [1, 0], t**2/4, pi**Rational(3, 2)) + add(cosh(t), [], [S.Half], [0], [S.Half, S.Half], t**2/4, pi**Rational(3, 2)) + + # Section 8.4.5 + # TODO can do t + a. but can also do by expansion... (XXX not really) + add(sin(t), [], [], [S.Half], [0], t**2/4, sqrt(pi)) + add(cos(t), [], [], [0], [S.Half], t**2/4, sqrt(pi)) + + # Section 8.4.6 (sinc function) + add(sinc(t), [], [], [0], [Rational(-1, 2)], t**2/4, sqrt(pi)/2) + + # Section 8.5.5 + def make_log1(subs): + N = subs[n] + return [(S.NegativeOne**N*factorial(N), + meijerg([], [1]*(N + 1), [0]*(N + 1), [], t))] + + def make_log2(subs): + N = subs[n] + return [(factorial(N), + meijerg([1]*(N + 1), [], [], [0]*(N + 1), t))] + # TODO these only hold for positive p, and can be made more general + # but who uses log(x)*Heaviside(a-x) anyway ... + # TODO also it would be nice to derive them recursively ... + addi(log(t)**n*Heaviside(1 - t), make_log1, True) + addi(log(t)**n*Heaviside(t - 1), make_log2, True) + + def make_log3(subs): + return make_log1(subs) + make_log2(subs) + addi(log(t)**n, make_log3, True) + addi(log(t + a), + constant(log(a)) + [(S.One, meijerg([1, 1], [], [1], [0], t/a))], + True) + addi(log(Abs(t - a)), constant(log(Abs(a))) + + [(pi, meijerg([1, 1], [S.Half], [1], [0, S.Half], t/a))], + True) + # TODO log(x)/(x+a) and log(x)/(x-1) can also be done. should they + # be derivable? + # TODO further formulae in this section seem obscure + + # Sections 8.4.9-10 + # TODO + + # Section 8.4.11 + addi(Ei(t), + constant(-S.ImaginaryUnit*pi) + [(S.NegativeOne, meijerg([], [1], [0, 0], [], + t*polar_lift(-1)))], + True) + + # Section 8.4.12 + add(Si(t), [1], [], [S.Half], [0, 0], t**2/4, sqrt(pi)/2) + add(Ci(t), [], [1], [0, 0], [S.Half], t**2/4, -sqrt(pi)/2) + + # Section 8.4.13 + add(Shi(t), [S.Half], [], [0], [Rational(-1, 2), Rational(-1, 2)], polar_lift(-1)*t**2/4, + t*sqrt(pi)/4) + add(Chi(t), [], [S.Half, 1], [0, 0], [S.Half, S.Half], t**2/4, - + pi**S('3/2')/2) + + # generalized exponential integral + add(expint(a, t), [], [a], [a - 1, 0], [], t) + + # Section 8.4.14 + add(erf(t), [1], [], [S.Half], [0], t**2, 1/sqrt(pi)) + # TODO exp(-x)*erf(I*x) does not work + add(erfc(t), [], [1], [0, S.Half], [], t**2, 1/sqrt(pi)) + # This formula for erfi(z) yields a wrong(?) minus sign + #add(erfi(t), [1], [], [S.Half], [0], -t**2, I/sqrt(pi)) + add(erfi(t), [S.Half], [], [0], [Rational(-1, 2)], -t**2, t/sqrt(pi)) + + # Fresnel Integrals + add(fresnels(t), [1], [], [Rational(3, 4)], [0, Rational(1, 4)], pi**2*t**4/16, S.Half) + add(fresnelc(t), [1], [], [Rational(1, 4)], [0, Rational(3, 4)], pi**2*t**4/16, S.Half) + + ##### bessel-type functions ##### + # Section 8.4.19 + add(besselj(a, t), [], [], [a/2], [-a/2], t**2/4) + + # all of the following are derivable + #add(sin(t)*besselj(a, t), [Rational(1, 4), Rational(3, 4)], [], [(1+a)/2], + # [-a/2, a/2, (1-a)/2], t**2, 1/sqrt(2)) + #add(cos(t)*besselj(a, t), [Rational(1, 4), Rational(3, 4)], [], [a/2], + # [-a/2, (1+a)/2, (1-a)/2], t**2, 1/sqrt(2)) + #add(besselj(a, t)**2, [S.Half], [], [a], [-a, 0], t**2, 1/sqrt(pi)) + #add(besselj(a, t)*besselj(b, t), [0, S.Half], [], [(a + b)/2], + # [-(a+b)/2, (a - b)/2, (b - a)/2], t**2, 1/sqrt(pi)) + + # Section 8.4.20 + add(bessely(a, t), [], [-(a + 1)/2], [a/2, -a/2], [-(a + 1)/2], t**2/4) + + # TODO all of the following should be derivable + #add(sin(t)*bessely(a, t), [Rational(1, 4), Rational(3, 4)], [(1 - a - 1)/2], + # [(1 + a)/2, (1 - a)/2], [(1 - a - 1)/2, (1 - 1 - a)/2, (1 - 1 + a)/2], + # t**2, 1/sqrt(2)) + #add(cos(t)*bessely(a, t), [Rational(1, 4), Rational(3, 4)], [(0 - a - 1)/2], + # [(0 + a)/2, (0 - a)/2], [(0 - a - 1)/2, (1 - 0 - a)/2, (1 - 0 + a)/2], + # t**2, 1/sqrt(2)) + #add(besselj(a, t)*bessely(b, t), [0, S.Half], [(a - b - 1)/2], + # [(a + b)/2, (a - b)/2], [(a - b - 1)/2, -(a + b)/2, (b - a)/2], + # t**2, 1/sqrt(pi)) + #addi(bessely(a, t)**2, + # [(2/sqrt(pi), meijerg([], [S.Half, S.Half - a], [0, a, -a], + # [S.Half - a], t**2)), + # (1/sqrt(pi), meijerg([S.Half], [], [a], [-a, 0], t**2))], + # True) + #addi(bessely(a, t)*bessely(b, t), + # [(2/sqrt(pi), meijerg([], [0, S.Half, (1 - a - b)/2], + # [(a + b)/2, (a - b)/2, (b - a)/2, -(a + b)/2], + # [(1 - a - b)/2], t**2)), + # (1/sqrt(pi), meijerg([0, S.Half], [], [(a + b)/2], + # [-(a + b)/2, (a - b)/2, (b - a)/2], t**2))], + # True) + + # Section 8.4.21 ? + # Section 8.4.22 + add(besseli(a, t), [], [(1 + a)/2], [a/2], [-a/2, (1 + a)/2], t**2/4, pi) + # TODO many more formulas. should all be derivable + + # Section 8.4.23 + add(besselk(a, t), [], [], [a/2, -a/2], [], t**2/4, S.Half) + # TODO many more formulas. should all be derivable + + # Complete elliptic integrals K(z) and E(z) + add(elliptic_k(t), [S.Half, S.Half], [], [0], [0], -t, S.Half) + add(elliptic_e(t), [S.Half, 3*S.Half], [], [0], [0], -t, Rational(-1, 2)/2) + + +#################################################################### +# First some helper functions. +#################################################################### + +from sympy.utilities.timeutils import timethis +timeit = timethis('meijerg') + + +def _mytype(f: Basic, x: Symbol) -> tuple[type[Basic], ...]: + """ Create a hashable entity describing the type of f. """ + def key(x: type[Basic]) -> tuple[int, int, str]: + return x.class_key() + + if x not in f.free_symbols: + return () + elif f.is_Function: + return type(f), + return tuple(sorted((t for a in f.args for t in _mytype(a, x)), key=key)) + + +class _CoeffExpValueError(ValueError): + """ + Exception raised by _get_coeff_exp, for internal use only. + """ + pass + + +def _get_coeff_exp(expr, x): + """ + When expr is known to be of the form c*x**b, with c and/or b possibly 1, + return c, b. + + Examples + ======== + + >>> from sympy.abc import x, a, b + >>> from sympy.integrals.meijerint import _get_coeff_exp + >>> _get_coeff_exp(a*x**b, x) + (a, b) + >>> _get_coeff_exp(x, x) + (1, 1) + >>> _get_coeff_exp(2*x, x) + (2, 1) + >>> _get_coeff_exp(x**3, x) + (1, 3) + """ + from sympy.simplify import powsimp + (c, m) = expand_power_base(powsimp(expr)).as_coeff_mul(x) + if not m: + return c, S.Zero + [m] = m + if m.is_Pow: + if m.base != x: + raise _CoeffExpValueError('expr not of form a*x**b') + return c, m.exp + elif m == x: + return c, S.One + else: + raise _CoeffExpValueError('expr not of form a*x**b: %s' % expr) + + +def _exponents(expr, x): + """ + Find the exponents of ``x`` (not including zero) in ``expr``. + + Examples + ======== + + >>> from sympy.integrals.meijerint import _exponents + >>> from sympy.abc import x, y + >>> from sympy import sin + >>> _exponents(x, x) + {1} + >>> _exponents(x**2, x) + {2} + >>> _exponents(x**2 + x, x) + {1, 2} + >>> _exponents(x**3*sin(x + x**y) + 1/x, x) + {-1, 1, 3, y} + """ + def _exponents_(expr, x, res): + if expr == x: + res.update([1]) + return + if expr.is_Pow and expr.base == x: + res.update([expr.exp]) + return + for argument in expr.args: + _exponents_(argument, x, res) + res = set() + _exponents_(expr, x, res) + return res + + +def _functions(expr, x): + """ Find the types of functions in expr, to estimate the complexity. """ + return {e.func for e in expr.atoms(Function) if x in e.free_symbols} + + +def _find_splitting_points(expr, x): + """ + Find numbers a such that a linear substitution x -> x + a would + (hopefully) simplify expr. + + Examples + ======== + + >>> from sympy.integrals.meijerint import _find_splitting_points as fsp + >>> from sympy import sin + >>> from sympy.abc import x + >>> fsp(x, x) + {0} + >>> fsp((x-1)**3, x) + {1} + >>> fsp(sin(x+3)*x, x) + {-3, 0} + """ + p, q = [Wild(n, exclude=[x]) for n in 'pq'] + + def compute_innermost(expr, res): + if not isinstance(expr, Expr): + return + m = expr.match(p*x + q) + if m and m[p] != 0: + res.add(-m[q]/m[p]) + return + if expr.is_Atom: + return + for argument in expr.args: + compute_innermost(argument, res) + innermost = set() + compute_innermost(expr, innermost) + return innermost + + +def _split_mul(f, x): + """ + Split expression ``f`` into fac, po, g, where fac is a constant factor, + po = x**s for some s independent of s, and g is "the rest". + + Examples + ======== + + >>> from sympy.integrals.meijerint import _split_mul + >>> from sympy import sin + >>> from sympy.abc import s, x + >>> _split_mul((3*x)**s*sin(x**2)*x, x) + (3**s, x*x**s, sin(x**2)) + """ + fac = S.One + po = S.One + g = S.One + f = expand_power_base(f) + + args = Mul.make_args(f) + for a in args: + if a == x: + po *= x + elif x not in a.free_symbols: + fac *= a + else: + if a.is_Pow and x not in a.exp.free_symbols: + c, t = a.base.as_coeff_mul(x) + if t != (x,): + c, t = expand_mul(a.base).as_coeff_mul(x) + if t == (x,): + po *= x**a.exp + fac *= unpolarify(polarify(c**a.exp, subs=False)) + continue + g *= a + + return fac, po, g + + +def _mul_args(f): + """ + Return a list ``L`` such that ``Mul(*L) == f``. + + If ``f`` is not a ``Mul`` or ``Pow``, ``L=[f]``. + If ``f=g**n`` for an integer ``n``, ``L=[g]*n``. + If ``f`` is a ``Mul``, ``L`` comes from applying ``_mul_args`` to all factors of ``f``. + """ + args = Mul.make_args(f) + gs = [] + for g in args: + if g.is_Pow and g.exp.is_Integer: + n = g.exp + base = g.base + if n < 0: + n = -n + base = 1/base + gs += [base]*n + else: + gs.append(g) + return gs + + +def _mul_as_two_parts(f): + """ + Find all the ways to split ``f`` into a product of two terms. + Return None on failure. + + Explanation + =========== + + Although the order is canonical from multiset_partitions, this is + not necessarily the best order to process the terms. For example, + if the case of len(gs) == 2 is removed and multiset is allowed to + sort the terms, some tests fail. + + Examples + ======== + + >>> from sympy.integrals.meijerint import _mul_as_two_parts + >>> from sympy import sin, exp, ordered + >>> from sympy.abc import x + >>> list(ordered(_mul_as_two_parts(x*sin(x)*exp(x)))) + [(x, exp(x)*sin(x)), (x*exp(x), sin(x)), (x*sin(x), exp(x))] + """ + + gs = _mul_args(f) + if len(gs) < 2: + return None + if len(gs) == 2: + return [tuple(gs)] + return [(Mul(*x), Mul(*y)) for (x, y) in multiset_partitions(gs, 2)] + + +def _inflate_g(g, n): + """ Return C, h such that h is a G function of argument z**n and + g = C*h. """ + # TODO should this be a method of meijerg? + # See: [L, page 150, equation (5)] + def inflate(params, n): + """ (a1, .., ak) -> (a1/n, (a1+1)/n, ..., (ak + n-1)/n) """ + return [(a + i)/n for a, i in itertools.product(params, range(n))] + v = S(len(g.ap) - len(g.bq)) + C = n**(1 + g.nu + v/2) + C /= (2*pi)**((n - 1)*g.delta) + return C, meijerg(inflate(g.an, n), inflate(g.aother, n), + inflate(g.bm, n), inflate(g.bother, n), + g.argument**n * n**(n*v)) + + +def _flip_g(g): + """ Turn the G function into one of inverse argument + (i.e. G(1/x) -> G'(x)) """ + # See [L], section 5.2 + def tr(l): + return [1 - a for a in l] + return meijerg(tr(g.bm), tr(g.bother), tr(g.an), tr(g.aother), 1/g.argument) + + +def _inflate_fox_h(g, a): + r""" + Let d denote the integrand in the definition of the G function ``g``. + Consider the function H which is defined in the same way, but with + integrand d/Gamma(a*s) (contour conventions as usual). + + If ``a`` is rational, the function H can be written as C*G, for a constant C + and a G-function G. + + This function returns C, G. + """ + if a < 0: + return _inflate_fox_h(_flip_g(g), -a) + p = S(a.p) + q = S(a.q) + # We use the substitution s->qs, i.e. inflate g by q. We are left with an + # extra factor of Gamma(p*s), for which we use Gauss' multiplication + # theorem. + D, g = _inflate_g(g, q) + z = g.argument + D /= (2*pi)**((1 - p)/2)*p**Rational(-1, 2) + z /= p**p + bs = [(n + 1)/p for n in range(p)] + return D, meijerg(g.an, g.aother, g.bm, list(g.bother) + bs, z) + + +_dummies: dict[tuple[str, str], Dummy] = {} + + +def _dummy(name, token, expr, **kwargs): + """ + Return a dummy. This will return the same dummy if the same token+name is + requested more than once, and it is not already in expr. + This is for being cache-friendly. + """ + d = _dummy_(name, token, **kwargs) + if d in expr.free_symbols: + return Dummy(name, **kwargs) + return d + + +def _dummy_(name, token, **kwargs): + """ + Return a dummy associated to name and token. Same effect as declaring + it globally. + """ + if not (name, token) in _dummies: + _dummies[(name, token)] = Dummy(name, **kwargs) + return _dummies[(name, token)] + + +def _is_analytic(f, x): + """ Check if f(x), when expressed using G functions on the positive reals, + will in fact agree with the G functions almost everywhere """ + return not any(x in expr.free_symbols for expr in f.atoms(Heaviside, Abs)) + + +def _condsimp(cond, first=True): + """ + Do naive simplifications on ``cond``. + + Explanation + =========== + + Note that this routine is completely ad-hoc, simplification rules being + added as need arises rather than following any logical pattern. + + Examples + ======== + + >>> from sympy.integrals.meijerint import _condsimp as simp + >>> from sympy import Or, Eq + >>> from sympy.abc import x, y + >>> simp(Or(x < y, Eq(x, y))) + x <= y + """ + if first: + cond = cond.replace(lambda _: _.is_Relational, _canonical_coeff) + first = False + if not isinstance(cond, BooleanFunction): + return cond + p, q, r = symbols('p q r', cls=Wild) + # transforms tests use 0, 4, 5 and 11-14 + # meijer tests use 0, 2, 11, 14 + # joint_rv uses 6, 7 + rules = [ + (Or(p < q, Eq(p, q)), p <= q), # 0 + # The next two obviously are instances of a general pattern, but it is + # easier to spell out the few cases we care about. + (And(Abs(arg(p)) <= pi, Abs(arg(p) - 2*pi) <= pi), + Eq(arg(p) - pi, 0)), # 1 + (And(Abs(2*arg(p) + pi) <= pi, Abs(2*arg(p) - pi) <= pi), + Eq(arg(p), 0)), # 2 + (And(Abs(2*arg(p) + pi) < pi, Abs(2*arg(p) - pi) <= pi), + S.false), # 3 + (And(Abs(arg(p) - pi/2) <= pi/2, Abs(arg(p) + pi/2) <= pi/2), + Eq(arg(p), 0)), # 4 + (And(Abs(arg(p) - pi/2) <= pi/2, Abs(arg(p) + pi/2) < pi/2), + S.false), # 5 + (And(Abs(arg(p**2/2 + 1)) < pi, Ne(Abs(arg(p**2/2 + 1)), pi)), + S.true), # 6 + (Or(Abs(arg(p**2/2 + 1)) < pi, Ne(1/(p**2/2 + 1), 0)), + S.true), # 7 + (And(Abs(unbranched_argument(p)) <= pi, + Abs(unbranched_argument(exp_polar(-2*pi*S.ImaginaryUnit)*p)) <= pi), + Eq(unbranched_argument(exp_polar(-S.ImaginaryUnit*pi)*p), 0)), # 8 + (And(Abs(unbranched_argument(p)) <= pi/2, + Abs(unbranched_argument(exp_polar(-pi*S.ImaginaryUnit)*p)) <= pi/2), + Eq(unbranched_argument(exp_polar(-S.ImaginaryUnit*pi/2)*p), 0)), # 9 + (Or(p <= q, And(p < q, r)), p <= q), # 10 + (Ne(p**2, 1) & (p**2 > 1), p**2 > 1), # 11 + (Ne(1/p, 1) & (cos(Abs(arg(p)))*Abs(p) > 1), Abs(p) > 1), # 12 + (Ne(p, 2) & (cos(Abs(arg(p)))*Abs(p) > 2), Abs(p) > 2), # 13 + ((Abs(arg(p)) < pi/2) & (cos(Abs(arg(p)))*sqrt(Abs(p**2)) > 1), p**2 > 1), # 14 + ] + cond = cond.func(*[_condsimp(_, first) for _ in cond.args]) + change = True + while change: + change = False + for irule, (fro, to) in enumerate(rules): + if fro.func != cond.func: + continue + for n, arg1 in enumerate(cond.args): + if r in fro.args[0].free_symbols: + m = arg1.match(fro.args[1]) + num = 1 + else: + num = 0 + m = arg1.match(fro.args[0]) + if not m: + continue + otherargs = [x.subs(m) for x in fro.args[:num] + fro.args[num + 1:]] + otherlist = [n] + for arg2 in otherargs: + for k, arg3 in enumerate(cond.args): + if k in otherlist: + continue + if arg2 == arg3: + otherlist += [k] + break + if isinstance(arg3, And) and arg2.args[1] == r and \ + isinstance(arg2, And) and arg2.args[0] in arg3.args: + otherlist += [k] + break + if isinstance(arg3, And) and arg2.args[0] == r and \ + isinstance(arg2, And) and arg2.args[1] in arg3.args: + otherlist += [k] + break + if len(otherlist) != len(otherargs) + 1: + continue + newargs = [arg_ for (k, arg_) in enumerate(cond.args) + if k not in otherlist] + [to.subs(m)] + if SYMPY_DEBUG: + if irule not in (0, 2, 4, 5, 6, 7, 11, 12, 13, 14): + print('used new rule:', irule) + cond = cond.func(*newargs) + change = True + break + + # final tweak + def rel_touchup(rel): + if rel.rel_op != '==' or rel.rhs != 0: + return rel + + # handle Eq(*, 0) + LHS = rel.lhs + m = LHS.match(arg(p)**q) + if not m: + m = LHS.match(unbranched_argument(polar_lift(p)**q)) + if not m: + if isinstance(LHS, periodic_argument) and not LHS.args[0].is_polar \ + and LHS.args[1] is S.Infinity: + return (LHS.args[0] > 0) + return rel + return (m[p] > 0) + cond = cond.replace(lambda _: _.is_Relational, rel_touchup) + if SYMPY_DEBUG: + print('_condsimp: ', cond) + return cond + +def _eval_cond(cond): + """ Re-evaluate the conditions. """ + if isinstance(cond, bool): + return cond + return _condsimp(cond.doit()) + +#################################################################### +# Now the "backbone" functions to do actual integration. +#################################################################### + + +def _my_principal_branch(expr, period, full_pb=False): + """ Bring expr nearer to its principal branch by removing superfluous + factors. + This function does *not* guarantee to yield the principal branch, + to avoid introducing opaque principal_branch() objects, + unless full_pb=True. """ + res = principal_branch(expr, period) + if not full_pb: + res = res.replace(principal_branch, lambda x, y: x) + return res + + +def _rewrite_saxena_1(fac, po, g, x): + """ + Rewrite the integral fac*po*g dx, from zero to infinity, as + integral fac*G, where G has argument a*x. Note po=x**s. + Return fac, G. + """ + _, s = _get_coeff_exp(po, x) + a, b = _get_coeff_exp(g.argument, x) + period = g.get_period() + a = _my_principal_branch(a, period) + + # We substitute t = x**b. + C = fac/(Abs(b)*a**((s + 1)/b - 1)) + # Absorb a factor of (at)**((1 + s)/b - 1). + + def tr(l): + return [a + (1 + s)/b - 1 for a in l] + return C, meijerg(tr(g.an), tr(g.aother), tr(g.bm), tr(g.bother), + a*x) + + +def _check_antecedents_1(g, x, helper=False): + r""" + Return a condition under which the mellin transform of g exists. + Any power of x has already been absorbed into the G function, + so this is just $\int_0^\infty g\, dx$. + + See [L, section 5.6.1]. (Note that s=1.) + + If ``helper`` is True, only check if the MT exists at infinity, i.e. if + $\int_1^\infty g\, dx$ exists. + """ + # NOTE if you update these conditions, please update the documentation as well + delta = g.delta + eta, _ = _get_coeff_exp(g.argument, x) + m, n, p, q = S([len(g.bm), len(g.an), len(g.ap), len(g.bq)]) + + if p > q: + def tr(l): + return [1 - x for x in l] + return _check_antecedents_1(meijerg(tr(g.bm), tr(g.bother), + tr(g.an), tr(g.aother), x/eta), + x) + + tmp = [-re(b) < 1 for b in g.bm] + [1 < 1 - re(a) for a in g.an] + cond_3 = And(*tmp) + + tmp += [-re(b) < 1 for b in g.bother] + tmp += [1 < 1 - re(a) for a in g.aother] + cond_3_star = And(*tmp) + + cond_4 = (-re(g.nu) + (q + 1 - p)/2 > q - p) + + def debug(*msg): + _debug(*msg) + + def debugf(string, arg): + _debugf(string, arg) + + debug('Checking antecedents for 1 function:') + debugf(' delta=%s, eta=%s, m=%s, n=%s, p=%s, q=%s', + (delta, eta, m, n, p, q)) + debugf(' ap = %s, %s', (list(g.an), list(g.aother))) + debugf(' bq = %s, %s', (list(g.bm), list(g.bother))) + debugf(' cond_3=%s, cond_3*=%s, cond_4=%s', (cond_3, cond_3_star, cond_4)) + + conds = [] + + # case 1 + case1 = [] + tmp1 = [1 <= n, p < q, 1 <= m] + tmp2 = [1 <= p, 1 <= m, Eq(q, p + 1), Not(And(Eq(n, 0), Eq(m, p + 1)))] + tmp3 = [1 <= p, Eq(q, p)] + for k in range(ceiling(delta/2) + 1): + tmp3 += [Ne(Abs(unbranched_argument(eta)), (delta - 2*k)*pi)] + tmp = [delta > 0, Abs(unbranched_argument(eta)) < delta*pi] + extra = [Ne(eta, 0), cond_3] + if helper: + extra = [] + for t in [tmp1, tmp2, tmp3]: + case1 += [And(*(t + tmp + extra))] + conds += case1 + debug(' case 1:', case1) + + # case 2 + extra = [cond_3] + if helper: + extra = [] + case2 = [And(Eq(n, 0), p + 1 <= m, m <= q, + Abs(unbranched_argument(eta)) < delta*pi, *extra)] + conds += case2 + debug(' case 2:', case2) + + # case 3 + extra = [cond_3, cond_4] + if helper: + extra = [] + case3 = [And(p < q, 1 <= m, delta > 0, Eq(Abs(unbranched_argument(eta)), delta*pi), + *extra)] + case3 += [And(p <= q - 2, Eq(delta, 0), Eq(Abs(unbranched_argument(eta)), 0), *extra)] + conds += case3 + debug(' case 3:', case3) + + # TODO altered cases 4-7 + + # extra case from wofram functions site: + # (reproduced verbatim from Prudnikov, section 2.24.2) + # https://functions.wolfram.com/HypergeometricFunctions/MeijerG/21/02/01/ + case_extra = [] + case_extra += [Eq(p, q), Eq(delta, 0), Eq(unbranched_argument(eta), 0), Ne(eta, 0)] + if not helper: + case_extra += [cond_3] + s = [] + for a, b in zip(g.ap, g.bq): + s += [b - a] + case_extra += [re(Add(*s)) < 0] + case_extra = And(*case_extra) + conds += [case_extra] + debug(' extra case:', [case_extra]) + + case_extra_2 = [And(delta > 0, Abs(unbranched_argument(eta)) < delta*pi)] + if not helper: + case_extra_2 += [cond_3] + case_extra_2 = And(*case_extra_2) + conds += [case_extra_2] + debug(' second extra case:', [case_extra_2]) + + # TODO This leaves only one case from the three listed by Prudnikov. + # Investigate if these indeed cover everything; if so, remove the rest. + + return Or(*conds) + + +def _int0oo_1(g, x): + r""" + Evaluate $\int_0^\infty g\, dx$ using G functions, + assuming the necessary conditions are fulfilled. + + Examples + ======== + + >>> from sympy.abc import a, b, c, d, x, y + >>> from sympy import meijerg + >>> from sympy.integrals.meijerint import _int0oo_1 + >>> _int0oo_1(meijerg([a], [b], [c], [d], x*y), x) + gamma(-a)*gamma(c + 1)/(y*gamma(-d)*gamma(b + 1)) + """ + from sympy.simplify import gammasimp + # See [L, section 5.6.1]. Note that s=1. + eta, _ = _get_coeff_exp(g.argument, x) + res = 1/eta + # XXX TODO we should reduce order first + for b in g.bm: + res *= gamma(b + 1) + for a in g.an: + res *= gamma(1 - a - 1) + for b in g.bother: + res /= gamma(1 - b - 1) + for a in g.aother: + res /= gamma(a + 1) + return gammasimp(unpolarify(res)) + + +def _rewrite_saxena(fac, po, g1, g2, x, full_pb=False): + """ + Rewrite the integral ``fac*po*g1*g2`` from 0 to oo in terms of G + functions with argument ``c*x``. + + Explanation + =========== + + Return C, f1, f2 such that integral C f1 f2 from 0 to infinity equals + integral fac ``po``, ``g1``, ``g2`` from 0 to infinity. + + Examples + ======== + + >>> from sympy.integrals.meijerint import _rewrite_saxena + >>> from sympy.abc import s, t, m + >>> from sympy import meijerg + >>> g1 = meijerg([], [], [0], [], s*t) + >>> g2 = meijerg([], [], [m/2], [-m/2], t**2/4) + >>> r = _rewrite_saxena(1, t**0, g1, g2, t) + >>> r[0] + s/(4*sqrt(pi)) + >>> r[1] + meijerg(((), ()), ((-1/2, 0), ()), s**2*t/4) + >>> r[2] + meijerg(((), ()), ((m/2,), (-m/2,)), t/4) + """ + def pb(g): + a, b = _get_coeff_exp(g.argument, x) + per = g.get_period() + return meijerg(g.an, g.aother, g.bm, g.bother, + _my_principal_branch(a, per, full_pb)*x**b) + + _, s = _get_coeff_exp(po, x) + _, b1 = _get_coeff_exp(g1.argument, x) + _, b2 = _get_coeff_exp(g2.argument, x) + if (b1 < 0) == True: + b1 = -b1 + g1 = _flip_g(g1) + if (b2 < 0) == True: + b2 = -b2 + g2 = _flip_g(g2) + if not b1.is_Rational or not b2.is_Rational: + return + m1, n1 = b1.p, b1.q + m2, n2 = b2.p, b2.q + tau = ilcm(m1*n2, m2*n1) + r1 = tau//(m1*n2) + r2 = tau//(m2*n1) + + C1, g1 = _inflate_g(g1, r1) + C2, g2 = _inflate_g(g2, r2) + g1 = pb(g1) + g2 = pb(g2) + + fac *= C1*C2 + a1, b = _get_coeff_exp(g1.argument, x) + a2, _ = _get_coeff_exp(g2.argument, x) + + # arbitrarily tack on the x**s part to g1 + # TODO should we try both? + exp = (s + 1)/b - 1 + fac = fac/(Abs(b) * a1**exp) + + def tr(l): + return [a + exp for a in l] + g1 = meijerg(tr(g1.an), tr(g1.aother), tr(g1.bm), tr(g1.bother), a1*x) + g2 = meijerg(g2.an, g2.aother, g2.bm, g2.bother, a2*x) + + from sympy.simplify import powdenest + return powdenest(fac, polar=True), g1, g2 + + +def _check_antecedents(g1, g2, x): + """ Return a condition under which the integral theorem applies. """ + # Yes, this is madness. + # XXX TODO this is a testing *nightmare* + # NOTE if you update these conditions, please update the documentation as well + + # The following conditions are found in + # [P], Section 2.24.1 + # + # They are also reproduced (verbatim!) at + # https://functions.wolfram.com/HypergeometricFunctions/MeijerG/21/02/03/ + # + # Note: k=l=r=alpha=1 + sigma, _ = _get_coeff_exp(g1.argument, x) + omega, _ = _get_coeff_exp(g2.argument, x) + s, t, u, v = S([len(g1.bm), len(g1.an), len(g1.ap), len(g1.bq)]) + m, n, p, q = S([len(g2.bm), len(g2.an), len(g2.ap), len(g2.bq)]) + bstar = s + t - (u + v)/2 + cstar = m + n - (p + q)/2 + rho = g1.nu + (u - v)/2 + 1 + mu = g2.nu + (p - q)/2 + 1 + phi = q - p - (v - u) + eta = 1 - (v - u) - mu - rho + psi = (pi*(q - m - n) + Abs(unbranched_argument(omega)))/(q - p) + theta = (pi*(v - s - t) + Abs(unbranched_argument(sigma)))/(v - u) + + _debug('Checking antecedents:') + _debugf(' sigma=%s, s=%s, t=%s, u=%s, v=%s, b*=%s, rho=%s', + (sigma, s, t, u, v, bstar, rho)) + _debugf(' omega=%s, m=%s, n=%s, p=%s, q=%s, c*=%s, mu=%s,', + (omega, m, n, p, q, cstar, mu)) + _debugf(' phi=%s, eta=%s, psi=%s, theta=%s', (phi, eta, psi, theta)) + + def _c1(): + for g in [g1, g2]: + for i, j in itertools.product(g.an, g.bm): + diff = i - j + if diff.is_integer and diff.is_positive: + return False + return True + c1 = _c1() + c2 = And(*[re(1 + i + j) > 0 for i in g1.bm for j in g2.bm]) + c3 = And(*[re(1 + i + j) < 1 + 1 for i in g1.an for j in g2.an]) + c4 = And(*[(p - q)*re(1 + i - 1) - re(mu) > Rational(-3, 2) for i in g1.an]) + c5 = And(*[(p - q)*re(1 + i) - re(mu) > Rational(-3, 2) for i in g1.bm]) + c6 = And(*[(u - v)*re(1 + i - 1) - re(rho) > Rational(-3, 2) for i in g2.an]) + c7 = And(*[(u - v)*re(1 + i) - re(rho) > Rational(-3, 2) for i in g2.bm]) + c8 = (Abs(phi) + 2*re((rho - 1)*(q - p) + (v - u)*(q - p) + (mu - + 1)*(v - u)) > 0) + c9 = (Abs(phi) - 2*re((rho - 1)*(q - p) + (v - u)*(q - p) + (mu - + 1)*(v - u)) > 0) + c10 = (Abs(unbranched_argument(sigma)) < bstar*pi) + c11 = Eq(Abs(unbranched_argument(sigma)), bstar*pi) + c12 = (Abs(unbranched_argument(omega)) < cstar*pi) + c13 = Eq(Abs(unbranched_argument(omega)), cstar*pi) + + # The following condition is *not* implemented as stated on the wolfram + # function site. In the book of Prudnikov there is an additional part + # (the And involving re()). However, I only have this book in russian, and + # I don't read any russian. The following condition is what other people + # have told me it means. + # Worryingly, it is different from the condition implemented in REDUCE. + # The REDUCE implementation: + # https://reduce-algebra.svn.sourceforge.net/svnroot/reduce-algebra/trunk/packages/defint/definta.red + # (search for tst14) + # The Wolfram alpha version: + # https://functions.wolfram.com/HypergeometricFunctions/MeijerG/21/02/03/03/0014/ + z0 = exp(-(bstar + cstar)*pi*S.ImaginaryUnit) + zos = unpolarify(z0*omega/sigma) + zso = unpolarify(z0*sigma/omega) + if zos == 1/zso: + c14 = And(Eq(phi, 0), bstar + cstar <= 1, + Or(Ne(zos, 1), re(mu + rho + v - u) < 1, + re(mu + rho + q - p) < 1)) + else: + def _cond(z): + '''Returns True if abs(arg(1-z)) < pi, avoiding arg(0). + + Explanation + =========== + + If ``z`` is 1 then arg is NaN. This raises a + TypeError on `NaN < pi`. Previously this gave `False` so + this behavior has been hardcoded here but someone should + check if this NaN is more serious! This NaN is triggered by + test_meijerint() in test_meijerint.py: + `meijerint_definite(exp(x), x, 0, I)` + ''' + return z != 1 and Abs(arg(1 - z)) < pi + + c14 = And(Eq(phi, 0), bstar - 1 + cstar <= 0, + Or(And(Ne(zos, 1), _cond(zos)), + And(re(mu + rho + v - u) < 1, Eq(zos, 1)))) + + c14_alt = And(Eq(phi, 0), cstar - 1 + bstar <= 0, + Or(And(Ne(zso, 1), _cond(zso)), + And(re(mu + rho + q - p) < 1, Eq(zso, 1)))) + + # Since r=k=l=1, in our case there is c14_alt which is the same as calling + # us with (g1, g2) = (g2, g1). The conditions below enumerate all cases + # (i.e. we don't have to try arguments reversed by hand), and indeed try + # all symmetric cases. (i.e. whenever there is a condition involving c14, + # there is also a dual condition which is exactly what we would get when g1, + # g2 were interchanged, *but c14 was unaltered*). + # Hence the following seems correct: + c14 = Or(c14, c14_alt) + + ''' + When `c15` is NaN (e.g. from `psi` being NaN as happens during + 'test_issue_4992' and/or `theta` is NaN as in 'test_issue_6253', + both in `test_integrals.py`) the comparison to 0 formerly gave False + whereas now an error is raised. To keep the old behavior, the value + of NaN is replaced with False but perhaps a closer look at this condition + should be made: XXX how should conditions leading to c15=NaN be handled? + ''' + try: + lambda_c = (q - p)*Abs(omega)**(1/(q - p))*cos(psi) \ + + (v - u)*Abs(sigma)**(1/(v - u))*cos(theta) + # the TypeError might be raised here, e.g. if lambda_c is NaN + if _eval_cond(lambda_c > 0) != False: + c15 = (lambda_c > 0) + else: + def lambda_s0(c1, c2): + return c1*(q - p)*Abs(omega)**(1/(q - p))*sin(psi) \ + + c2*(v - u)*Abs(sigma)**(1/(v - u))*sin(theta) + lambda_s = Piecewise( + ((lambda_s0(+1, +1)*lambda_s0(-1, -1)), + And(Eq(unbranched_argument(sigma), 0), Eq(unbranched_argument(omega), 0))), + (lambda_s0(sign(unbranched_argument(omega)), +1)*lambda_s0(sign(unbranched_argument(omega)), -1), + And(Eq(unbranched_argument(sigma), 0), Ne(unbranched_argument(omega), 0))), + (lambda_s0(+1, sign(unbranched_argument(sigma)))*lambda_s0(-1, sign(unbranched_argument(sigma))), + And(Ne(unbranched_argument(sigma), 0), Eq(unbranched_argument(omega), 0))), + (lambda_s0(sign(unbranched_argument(omega)), sign(unbranched_argument(sigma))), True)) + tmp = [lambda_c > 0, + And(Eq(lambda_c, 0), Ne(lambda_s, 0), re(eta) > -1), + And(Eq(lambda_c, 0), Eq(lambda_s, 0), re(eta) > 0)] + c15 = Or(*tmp) + except TypeError: + c15 = False + for cond, i in [(c1, 1), (c2, 2), (c3, 3), (c4, 4), (c5, 5), (c6, 6), + (c7, 7), (c8, 8), (c9, 9), (c10, 10), (c11, 11), + (c12, 12), (c13, 13), (c14, 14), (c15, 15)]: + _debugf(' c%s: %s', (i, cond)) + + # We will return Or(*conds) + conds = [] + + def pr(count): + _debugf(' case %s: %s', (count, conds[-1])) + conds += [And(m*n*s*t != 0, bstar.is_positive is True, cstar.is_positive is True, c1, c2, c3, c10, + c12)] # 1 + pr(1) + conds += [And(Eq(u, v), Eq(bstar, 0), cstar.is_positive is True, sigma.is_positive is True, re(rho) < 1, + c1, c2, c3, c12)] # 2 + pr(2) + conds += [And(Eq(p, q), Eq(cstar, 0), bstar.is_positive is True, omega.is_positive is True, re(mu) < 1, + c1, c2, c3, c10)] # 3 + pr(3) + conds += [And(Eq(p, q), Eq(u, v), Eq(bstar, 0), Eq(cstar, 0), + sigma.is_positive is True, omega.is_positive is True, re(mu) < 1, re(rho) < 1, + Ne(sigma, omega), c1, c2, c3)] # 4 + pr(4) + conds += [And(Eq(p, q), Eq(u, v), Eq(bstar, 0), Eq(cstar, 0), + sigma.is_positive is True, omega.is_positive is True, re(mu + rho) < 1, + Ne(omega, sigma), c1, c2, c3)] # 5 + pr(5) + conds += [And(p > q, s.is_positive is True, bstar.is_positive is True, cstar >= 0, + c1, c2, c3, c5, c10, c13)] # 6 + pr(6) + conds += [And(p < q, t.is_positive is True, bstar.is_positive is True, cstar >= 0, + c1, c2, c3, c4, c10, c13)] # 7 + pr(7) + conds += [And(u > v, m.is_positive is True, cstar.is_positive is True, bstar >= 0, + c1, c2, c3, c7, c11, c12)] # 8 + pr(8) + conds += [And(u < v, n.is_positive is True, cstar.is_positive is True, bstar >= 0, + c1, c2, c3, c6, c11, c12)] # 9 + pr(9) + conds += [And(p > q, Eq(u, v), Eq(bstar, 0), cstar >= 0, sigma.is_positive is True, + re(rho) < 1, c1, c2, c3, c5, c13)] # 10 + pr(10) + conds += [And(p < q, Eq(u, v), Eq(bstar, 0), cstar >= 0, sigma.is_positive is True, + re(rho) < 1, c1, c2, c3, c4, c13)] # 11 + pr(11) + conds += [And(Eq(p, q), u > v, bstar >= 0, Eq(cstar, 0), omega.is_positive is True, + re(mu) < 1, c1, c2, c3, c7, c11)] # 12 + pr(12) + conds += [And(Eq(p, q), u < v, bstar >= 0, Eq(cstar, 0), omega.is_positive is True, + re(mu) < 1, c1, c2, c3, c6, c11)] # 13 + pr(13) + conds += [And(p < q, u > v, bstar >= 0, cstar >= 0, + c1, c2, c3, c4, c7, c11, c13)] # 14 + pr(14) + conds += [And(p > q, u < v, bstar >= 0, cstar >= 0, + c1, c2, c3, c5, c6, c11, c13)] # 15 + pr(15) + conds += [And(p > q, u > v, bstar >= 0, cstar >= 0, + c1, c2, c3, c5, c7, c8, c11, c13, c14)] # 16 + pr(16) + conds += [And(p < q, u < v, bstar >= 0, cstar >= 0, + c1, c2, c3, c4, c6, c9, c11, c13, c14)] # 17 + pr(17) + conds += [And(Eq(t, 0), s.is_positive is True, bstar.is_positive is True, phi.is_positive is True, c1, c2, c10)] # 18 + pr(18) + conds += [And(Eq(s, 0), t.is_positive is True, bstar.is_positive is True, phi.is_negative is True, c1, c3, c10)] # 19 + pr(19) + conds += [And(Eq(n, 0), m.is_positive is True, cstar.is_positive is True, phi.is_negative is True, c1, c2, c12)] # 20 + pr(20) + conds += [And(Eq(m, 0), n.is_positive is True, cstar.is_positive is True, phi.is_positive is True, c1, c3, c12)] # 21 + pr(21) + conds += [And(Eq(s*t, 0), bstar.is_positive is True, cstar.is_positive is True, + c1, c2, c3, c10, c12)] # 22 + pr(22) + conds += [And(Eq(m*n, 0), bstar.is_positive is True, cstar.is_positive is True, + c1, c2, c3, c10, c12)] # 23 + pr(23) + + # The following case is from [Luke1969]. As far as I can tell, it is *not* + # covered by Prudnikov's. + # Let G1 and G2 be the two G-functions. Suppose the integral exists from + # 0 to a > 0 (this is easy the easy part), that G1 is exponential decay at + # infinity, and that the mellin transform of G2 exists. + # Then the integral exists. + mt1_exists = _check_antecedents_1(g1, x, helper=True) + mt2_exists = _check_antecedents_1(g2, x, helper=True) + conds += [And(mt2_exists, Eq(t, 0), u < s, bstar.is_positive is True, c10, c1, c2, c3)] + pr('E1') + conds += [And(mt2_exists, Eq(s, 0), v < t, bstar.is_positive is True, c10, c1, c2, c3)] + pr('E2') + conds += [And(mt1_exists, Eq(n, 0), p < m, cstar.is_positive is True, c12, c1, c2, c3)] + pr('E3') + conds += [And(mt1_exists, Eq(m, 0), q < n, cstar.is_positive is True, c12, c1, c2, c3)] + pr('E4') + + # Let's short-circuit if this worked ... + # the rest is corner-cases and terrible to read. + r = Or(*conds) + if _eval_cond(r) != False: + return r + + conds += [And(m + n > p, Eq(t, 0), Eq(phi, 0), s.is_positive is True, bstar.is_positive is True, cstar.is_negative is True, + Abs(unbranched_argument(omega)) < (m + n - p + 1)*pi, + c1, c2, c10, c14, c15)] # 24 + pr(24) + conds += [And(m + n > q, Eq(s, 0), Eq(phi, 0), t.is_positive is True, bstar.is_positive is True, cstar.is_negative is True, + Abs(unbranched_argument(omega)) < (m + n - q + 1)*pi, + c1, c3, c10, c14, c15)] # 25 + pr(25) + conds += [And(Eq(p, q - 1), Eq(t, 0), Eq(phi, 0), s.is_positive is True, bstar.is_positive is True, + cstar >= 0, cstar*pi < Abs(unbranched_argument(omega)), + c1, c2, c10, c14, c15)] # 26 + pr(26) + conds += [And(Eq(p, q + 1), Eq(s, 0), Eq(phi, 0), t.is_positive is True, bstar.is_positive is True, + cstar >= 0, cstar*pi < Abs(unbranched_argument(omega)), + c1, c3, c10, c14, c15)] # 27 + pr(27) + conds += [And(p < q - 1, Eq(t, 0), Eq(phi, 0), s.is_positive is True, bstar.is_positive is True, + cstar >= 0, cstar*pi < Abs(unbranched_argument(omega)), + Abs(unbranched_argument(omega)) < (m + n - p + 1)*pi, + c1, c2, c10, c14, c15)] # 28 + pr(28) + conds += [And( + p > q + 1, Eq(s, 0), Eq(phi, 0), t.is_positive is True, bstar.is_positive is True, cstar >= 0, + cstar*pi < Abs(unbranched_argument(omega)), + Abs(unbranched_argument(omega)) < (m + n - q + 1)*pi, + c1, c3, c10, c14, c15)] # 29 + pr(29) + conds += [And(Eq(n, 0), Eq(phi, 0), s + t > 0, m.is_positive is True, cstar.is_positive is True, bstar.is_negative is True, + Abs(unbranched_argument(sigma)) < (s + t - u + 1)*pi, + c1, c2, c12, c14, c15)] # 30 + pr(30) + conds += [And(Eq(m, 0), Eq(phi, 0), s + t > v, n.is_positive is True, cstar.is_positive is True, bstar.is_negative is True, + Abs(unbranched_argument(sigma)) < (s + t - v + 1)*pi, + c1, c3, c12, c14, c15)] # 31 + pr(31) + conds += [And(Eq(n, 0), Eq(phi, 0), Eq(u, v - 1), m.is_positive is True, cstar.is_positive is True, + bstar >= 0, bstar*pi < Abs(unbranched_argument(sigma)), + Abs(unbranched_argument(sigma)) < (bstar + 1)*pi, + c1, c2, c12, c14, c15)] # 32 + pr(32) + conds += [And(Eq(m, 0), Eq(phi, 0), Eq(u, v + 1), n.is_positive is True, cstar.is_positive is True, + bstar >= 0, bstar*pi < Abs(unbranched_argument(sigma)), + Abs(unbranched_argument(sigma)) < (bstar + 1)*pi, + c1, c3, c12, c14, c15)] # 33 + pr(33) + conds += [And( + Eq(n, 0), Eq(phi, 0), u < v - 1, m.is_positive is True, cstar.is_positive is True, bstar >= 0, + bstar*pi < Abs(unbranched_argument(sigma)), + Abs(unbranched_argument(sigma)) < (s + t - u + 1)*pi, + c1, c2, c12, c14, c15)] # 34 + pr(34) + conds += [And( + Eq(m, 0), Eq(phi, 0), u > v + 1, n.is_positive is True, cstar.is_positive is True, bstar >= 0, + bstar*pi < Abs(unbranched_argument(sigma)), + Abs(unbranched_argument(sigma)) < (s + t - v + 1)*pi, + c1, c3, c12, c14, c15)] # 35 + pr(35) + + return Or(*conds) + + # NOTE An alternative, but as far as I can tell weaker, set of conditions + # can be found in [L, section 5.6.2]. + + +def _int0oo(g1, g2, x): + """ + Express integral from zero to infinity g1*g2 using a G function, + assuming the necessary conditions are fulfilled. + + Examples + ======== + + >>> from sympy.integrals.meijerint import _int0oo + >>> from sympy.abc import s, t, m + >>> from sympy import meijerg, S + >>> g1 = meijerg([], [], [-S(1)/2, 0], [], s**2*t/4) + >>> g2 = meijerg([], [], [m/2], [-m/2], t/4) + >>> _int0oo(g1, g2, t) + 4*meijerg(((0, 1/2), ()), ((m/2,), (-m/2,)), s**(-2))/s**2 + """ + # See: [L, section 5.6.2, equation (1)] + eta, _ = _get_coeff_exp(g1.argument, x) + omega, _ = _get_coeff_exp(g2.argument, x) + + def neg(l): + return [-x for x in l] + a1 = neg(g1.bm) + list(g2.an) + a2 = list(g2.aother) + neg(g1.bother) + b1 = neg(g1.an) + list(g2.bm) + b2 = list(g2.bother) + neg(g1.aother) + return meijerg(a1, a2, b1, b2, omega/eta)/eta + + +def _rewrite_inversion(fac, po, g, x): + """ Absorb ``po`` == x**s into g. """ + _, s = _get_coeff_exp(po, x) + a, b = _get_coeff_exp(g.argument, x) + + def tr(l): + return [t + s/b for t in l] + from sympy.simplify import powdenest + return (powdenest(fac/a**(s/b), polar=True), + meijerg(tr(g.an), tr(g.aother), tr(g.bm), tr(g.bother), g.argument)) + + +def _check_antecedents_inversion(g, x): + """ Check antecedents for the laplace inversion integral. """ + _debug('Checking antecedents for inversion:') + z = g.argument + _, e = _get_coeff_exp(z, x) + if e < 0: + _debug(' Flipping G.') + # We want to assume that argument gets large as |x| -> oo + return _check_antecedents_inversion(_flip_g(g), x) + + def statement_half(a, b, c, z, plus): + coeff, exponent = _get_coeff_exp(z, x) + a *= exponent + b *= coeff**c + c *= exponent + conds = [] + wp = b*exp(S.ImaginaryUnit*re(c)*pi/2) + wm = b*exp(-S.ImaginaryUnit*re(c)*pi/2) + if plus: + w = wp + else: + w = wm + conds += [And(Or(Eq(b, 0), re(c) <= 0), re(a) <= -1)] + conds += [And(Ne(b, 0), Eq(im(c), 0), re(c) > 0, re(w) < 0)] + conds += [And(Ne(b, 0), Eq(im(c), 0), re(c) > 0, re(w) <= 0, + re(a) <= -1)] + return Or(*conds) + + def statement(a, b, c, z): + """ Provide a convergence statement for z**a * exp(b*z**c), + c/f sphinx docs. """ + return And(statement_half(a, b, c, z, True), + statement_half(a, b, c, z, False)) + + # Notations from [L], section 5.7-10 + m, n, p, q = S([len(g.bm), len(g.an), len(g.ap), len(g.bq)]) + tau = m + n - p + nu = q - m - n + rho = (tau - nu)/2 + sigma = q - p + if sigma == 1: + epsilon = S.Half + elif sigma > 1: + epsilon = 1 + else: + epsilon = S.NaN + theta = ((1 - sigma)/2 + Add(*g.bq) - Add(*g.ap))/sigma + delta = g.delta + _debugf(' m=%s, n=%s, p=%s, q=%s, tau=%s, nu=%s, rho=%s, sigma=%s', + (m, n, p, q, tau, nu, rho, sigma)) + _debugf(' epsilon=%s, theta=%s, delta=%s', (epsilon, theta, delta)) + + # First check if the computation is valid. + if not (g.delta >= e/2 or (p >= 1 and p >= q)): + _debug(' Computation not valid for these parameters.') + return False + + # Now check if the inversion integral exists. + + # Test "condition A" + for a, b in itertools.product(g.an, g.bm): + if (a - b).is_integer and a > b: + _debug(' Not a valid G function.') + return False + + # There are two cases. If p >= q, we can directly use a slater expansion + # like [L], 5.2 (11). Note in particular that the asymptotics of such an + # expansion even hold when some of the parameters differ by integers, i.e. + # the formula itself would not be valid! (b/c G functions are cts. in their + # parameters) + # When p < q, we need to use the theorems of [L], 5.10. + + if p >= q: + _debug(' Using asymptotic Slater expansion.') + return And(*[statement(a - 1, 0, 0, z) for a in g.an]) + + def E(z): + return And(*[statement(a - 1, 0, 0, z) for a in g.an]) + + def H(z): + return statement(theta, -sigma, 1/sigma, z) + + def Hp(z): + return statement_half(theta, -sigma, 1/sigma, z, True) + + def Hm(z): + return statement_half(theta, -sigma, 1/sigma, z, False) + + # [L], section 5.10 + conds = [] + # Theorem 1 -- p < q from test above + conds += [And(1 <= n, 1 <= m, rho*pi - delta >= pi/2, delta > 0, + E(z*exp(S.ImaginaryUnit*pi*(nu + 1))))] + # Theorem 2, statements (2) and (3) + conds += [And(p + 1 <= m, m + 1 <= q, delta > 0, delta < pi/2, n == 0, + (m - p + 1)*pi - delta >= pi/2, + Hp(z*exp(S.ImaginaryUnit*pi*(q - m))), + Hm(z*exp(-S.ImaginaryUnit*pi*(q - m))))] + # Theorem 2, statement (5) -- p < q from test above + conds += [And(m == q, n == 0, delta > 0, + (sigma + epsilon)*pi - delta >= pi/2, H(z))] + # Theorem 3, statements (6) and (7) + conds += [And(Or(And(p <= q - 2, 1 <= tau, tau <= sigma/2), + And(p + 1 <= m + n, m + n <= (p + q)/2)), + delta > 0, delta < pi/2, (tau + 1)*pi - delta >= pi/2, + Hp(z*exp(S.ImaginaryUnit*pi*nu)), + Hm(z*exp(-S.ImaginaryUnit*pi*nu)))] + # Theorem 4, statements (10) and (11) -- p < q from test above + conds += [And(1 <= m, rho > 0, delta > 0, delta + rho*pi < pi/2, + (tau + epsilon)*pi - delta >= pi/2, + Hp(z*exp(S.ImaginaryUnit*pi*nu)), + Hm(z*exp(-S.ImaginaryUnit*pi*nu)))] + # Trivial case + conds += [m == 0] + + # TODO + # Theorem 5 is quite general + # Theorem 6 contains special cases for q=p+1 + + return Or(*conds) + + +def _int_inversion(g, x, t): + """ + Compute the laplace inversion integral, assuming the formula applies. + """ + b, a = _get_coeff_exp(g.argument, x) + C, g = _inflate_fox_h(meijerg(g.an, g.aother, g.bm, g.bother, b/t**a), -a) + return C/t*g + + +#################################################################### +# Finally, the real meat. +#################################################################### + +_lookup_table = None + + +@cacheit +@timeit +def _rewrite_single(f, x, recursive=True): + """ + Try to rewrite f as a sum of single G functions of the form + C*x**s*G(a*x**b), where b is a rational number and C is independent of x. + We guarantee that result.argument.as_coeff_mul(x) returns (a, (x**b,)) + or (a, ()). + Returns a list of tuples (C, s, G) and a condition cond. + Returns None on failure. + """ + from .transforms import (mellin_transform, inverse_mellin_transform, + IntegralTransformError, MellinTransformStripError) + + global _lookup_table + if not _lookup_table: + _lookup_table = {} + _create_lookup_table(_lookup_table) + + if isinstance(f, meijerg): + coeff, m = factor(f.argument, x).as_coeff_mul(x) + if len(m) > 1: + return None + m = m[0] + if m.is_Pow: + if m.base != x or not m.exp.is_Rational: + return None + elif m != x: + return None + return [(1, 0, meijerg(f.an, f.aother, f.bm, f.bother, coeff*m))], True + + f_ = f + f = f.subs(x, z) + t = _mytype(f, z) + if t in _lookup_table: + l = _lookup_table[t] + for formula, terms, cond, hint in l: + subs = f.match(formula, old=True) + if subs: + subs_ = {} + for fro, to in subs.items(): + subs_[fro] = unpolarify(polarify(to, lift=True), + exponents_only=True) + subs = subs_ + if not isinstance(hint, bool): + hint = hint.subs(subs) + if hint == False: + continue + if not isinstance(cond, (bool, BooleanAtom)): + cond = unpolarify(cond.subs(subs)) + if _eval_cond(cond) == False: + continue + if not isinstance(terms, list): + terms = terms(subs) + res = [] + for fac, g in terms: + r1 = _get_coeff_exp(unpolarify(fac.subs(subs).subs(z, x), + exponents_only=True), x) + try: + g = g.subs(subs).subs(z, x) + except ValueError: + continue + # NOTE these substitutions can in principle introduce oo, + # zoo and other absurdities. It shouldn't matter, + # but better be safe. + if Tuple(*(r1 + (g,))).has(S.Infinity, S.ComplexInfinity, S.NegativeInfinity): + continue + g = meijerg(g.an, g.aother, g.bm, g.bother, + unpolarify(g.argument, exponents_only=True)) + res.append(r1 + (g,)) + if res: + return res, cond + + # try recursive mellin transform + if not recursive: + return None + _debug('Trying recursive Mellin transform method.') + + def my_imt(F, s, x, strip): + """ Calling simplify() all the time is slow and not helpful, since + most of the time it only factors things in a way that has to be + un-done anyway. But sometimes it can remove apparent poles. """ + # XXX should this be in inverse_mellin_transform? + try: + return inverse_mellin_transform(F, s, x, strip, + as_meijerg=True, needeval=True) + except MellinTransformStripError: + from sympy.simplify import simplify + return inverse_mellin_transform( + simplify(cancel(expand(F))), s, x, strip, + as_meijerg=True, needeval=True) + f = f_ + s = _dummy('s', 'rewrite-single', f) + # to avoid infinite recursion, we have to force the two g functions case + + def my_integrator(f, x): + r = _meijerint_definite_4(f, x, only_double=True) + if r is not None: + from sympy.simplify import hyperexpand + res, cond = r + res = _my_unpolarify(hyperexpand(res, rewrite='nonrepsmall')) + return Piecewise((res, cond), + (Integral(f, (x, S.Zero, S.Infinity)), True)) + return Integral(f, (x, S.Zero, S.Infinity)) + try: + F, strip, _ = mellin_transform(f, x, s, integrator=my_integrator, + simplify=False, needeval=True) + g = my_imt(F, s, x, strip) + except IntegralTransformError: + g = None + if g is None: + # We try to find an expression by analytic continuation. + # (also if the dummy is already in the expression, there is no point in + # putting in another one) + a = _dummy_('a', 'rewrite-single') + if a not in f.free_symbols and _is_analytic(f, x): + try: + F, strip, _ = mellin_transform(f.subs(x, a*x), x, s, + integrator=my_integrator, + needeval=True, simplify=False) + g = my_imt(F, s, x, strip).subs(a, 1) + except IntegralTransformError: + g = None + if g is None or g.has(S.Infinity, S.NaN, S.ComplexInfinity): + _debug('Recursive Mellin transform failed.') + return None + args = Add.make_args(g) + res = [] + for f in args: + c, m = f.as_coeff_mul(x) + if len(m) > 1: + raise NotImplementedError('Unexpected form...') + g = m[0] + a, b = _get_coeff_exp(g.argument, x) + res += [(c, 0, meijerg(g.an, g.aother, g.bm, g.bother, + unpolarify(polarify( + a, lift=True), exponents_only=True) + *x**b))] + _debug('Recursive Mellin transform worked:', g) + return res, True + + +def _rewrite1(f, x, recursive=True): + """ + Try to rewrite ``f`` using a (sum of) single G functions with argument a*x**b. + Return fac, po, g such that f = fac*po*g, fac is independent of ``x``. + and po = x**s. + Here g is a result from _rewrite_single. + Return None on failure. + """ + fac, po, g = _split_mul(f, x) + g = _rewrite_single(g, x, recursive) + if g: + return fac, po, g[0], g[1] + + +def _rewrite2(f, x): + """ + Try to rewrite ``f`` as a product of two G functions of arguments a*x**b. + Return fac, po, g1, g2 such that f = fac*po*g1*g2, where fac is + independent of x and po is x**s. + Here g1 and g2 are results of _rewrite_single. + Returns None on failure. + """ + fac, po, g = _split_mul(f, x) + if any(_rewrite_single(expr, x, False) is None for expr in _mul_args(g)): + return None + l = _mul_as_two_parts(g) + if not l: + return None + l = list(ordered(l, [ + lambda p: max(len(_exponents(p[0], x)), len(_exponents(p[1], x))), + lambda p: max(len(_functions(p[0], x)), len(_functions(p[1], x))), + lambda p: max(len(_find_splitting_points(p[0], x)), + len(_find_splitting_points(p[1], x)))])) + + for recursive, (fac1, fac2) in itertools.product((False, True), l): + g1 = _rewrite_single(fac1, x, recursive) + g2 = _rewrite_single(fac2, x, recursive) + if g1 and g2: + cond = And(g1[1], g2[1]) + if cond != False: + return fac, po, g1[0], g2[0], cond + + +def meijerint_indefinite(f, x): + """ + Compute an indefinite integral of ``f`` by rewriting it as a G function. + + Examples + ======== + + >>> from sympy.integrals.meijerint import meijerint_indefinite + >>> from sympy import sin + >>> from sympy.abc import x + >>> meijerint_indefinite(sin(x), x) + -cos(x) + """ + f = sympify(f) + results = [] + for a in sorted(_find_splitting_points(f, x) | {S.Zero}, key=default_sort_key): + res = _meijerint_indefinite_1(f.subs(x, x + a), x) + if not res: + continue + res = res.subs(x, x - a) + if _has(res, hyper, meijerg): + results.append(res) + else: + return res + if f.has(HyperbolicFunction): + _debug('Try rewriting hyperbolics in terms of exp.') + rv = meijerint_indefinite( + _rewrite_hyperbolics_as_exp(f), x) + if rv: + if not isinstance(rv, list): + from sympy.simplify.radsimp import collect + return collect(factor_terms(rv), rv.atoms(exp)) + results.extend(rv) + if results: + return next(ordered(results)) + + +def _meijerint_indefinite_1(f, x): + """ Helper that does not attempt any substitution. """ + _debug('Trying to compute the indefinite integral of', f, 'wrt', x) + from sympy.simplify import hyperexpand, powdenest + + gs = _rewrite1(f, x) + if gs is None: + # Note: the code that calls us will do expand() and try again + return None + + fac, po, gl, cond = gs + _debug(' could rewrite:', gs) + res = S.Zero + for C, s, g in gl: + a, b = _get_coeff_exp(g.argument, x) + _, c = _get_coeff_exp(po, x) + c += s + + # we do a substitution t=a*x**b, get integrand fac*t**rho*g + fac_ = fac * C * x**(1 + c) / b + rho = (c + 1)/b + + # we now use t**rho*G(params, t) = G(params + rho, t) + # [L, page 150, equation (4)] + # and integral G(params, t) dt = G(1, params+1, 0, t) + # (or a similar expression with 1 and 0 exchanged ... pick the one + # which yields a well-defined function) + # [R, section 5] + # (Note that this dummy will immediately go away again, so we + # can safely pass S.One for ``expr``.) + t = _dummy('t', 'meijerint-indefinite', S.One) + + def tr(p): + return [a + rho for a in p] + if any(b.is_integer and (b <= 0) == True for b in tr(g.bm)): + r = -meijerg( + list(g.an), list(g.aother) + [1-rho], list(g.bm) + [-rho], list(g.bother), t) + else: + r = meijerg( + list(g.an) + [1-rho], list(g.aother), list(g.bm), list(g.bother) + [-rho], t) + # The antiderivative is most often expected to be defined + # in the neighborhood of x = 0. + if b.is_extended_nonnegative and not f.subs(x, 0).has(S.NaN, S.ComplexInfinity): + place = 0 # Assume we can expand at zero + else: + place = None + r = hyperexpand(r.subs(t, a*x**b), place=place) + + # now substitute back + # Note: we really do want the powers of x to combine. + res += powdenest(fac_*r, polar=True) + + def _clean(res): + """This multiplies out superfluous powers of x we created, and chops off + constants: + + >> _clean(x*(exp(x)/x - 1/x) + 3) + exp(x) + + cancel is used before mul_expand since it is possible for an + expression to have an additive constant that does not become isolated + with simple expansion. Such a situation was identified in issue 6369: + + Examples + ======== + + >>> from sympy import sqrt, cancel + >>> from sympy.abc import x + >>> a = sqrt(2*x + 1) + >>> bad = (3*x*a**5 + 2*x - a**5 + 1)/a**2 + >>> bad.expand().as_independent(x)[0] + 0 + >>> cancel(bad).expand().as_independent(x)[0] + 1 + """ + res = expand_mul(cancel(res), deep=False) + return Add._from_args(res.as_coeff_add(x)[1]) + + res = piecewise_fold(res, evaluate=None) + if res.is_Piecewise: + newargs = [] + for e, c in res.args: + e = _my_unpolarify(_clean(e)) + newargs += [(e, c)] + res = Piecewise(*newargs, evaluate=False) + else: + res = _my_unpolarify(_clean(res)) + return Piecewise((res, _my_unpolarify(cond)), (Integral(f, x), True)) + + +@timeit +def meijerint_definite(f, x, a, b): + """ + Integrate ``f`` over the interval [``a``, ``b``], by rewriting it as a product + of two G functions, or as a single G function. + + Return res, cond, where cond are convergence conditions. + + Examples + ======== + + >>> from sympy.integrals.meijerint import meijerint_definite + >>> from sympy import exp, oo + >>> from sympy.abc import x + >>> meijerint_definite(exp(-x**2), x, -oo, oo) + (sqrt(pi), True) + + This function is implemented as a succession of functions + meijerint_definite, _meijerint_definite_2, _meijerint_definite_3, + _meijerint_definite_4. Each function in the list calls the next one + (presumably) several times. This means that calling meijerint_definite + can be very costly. + """ + # This consists of three steps: + # 1) Change the integration limits to 0, oo + # 2) Rewrite in terms of G functions + # 3) Evaluate the integral + # + # There are usually several ways of doing this, and we want to try all. + # This function does (1), calls _meijerint_definite_2 for step (2). + _debugf('Integrating %s wrt %s from %s to %s.', (f, x, a, b)) + f = sympify(f) + if f.has(DiracDelta): + _debug('Integrand has DiracDelta terms - giving up.') + return None + + if f.has(SingularityFunction): + _debug('Integrand has Singularity Function terms - giving up.') + return None + + f_, x_, a_, b_ = f, x, a, b + + # Let's use a dummy in case any of the boundaries has x. + d = Dummy('x') + f = f.subs(x, d) + x = d + + if a == b: + return (S.Zero, True) + + results = [] + if a is S.NegativeInfinity and b is not S.Infinity: + return meijerint_definite(f.subs(x, -x), x, -b, -a) + + elif a is S.NegativeInfinity: + # Integrating -oo to oo. We need to find a place to split the integral. + _debug(' Integrating -oo to +oo.') + innermost = _find_splitting_points(f, x) + _debug(' Sensible splitting points:', innermost) + for c in sorted(innermost, key=default_sort_key, reverse=True) + [S.Zero]: + _debug(' Trying to split at', c) + if not c.is_extended_real: + _debug(' Non-real splitting point.') + continue + res1 = _meijerint_definite_2(f.subs(x, x + c), x) + if res1 is None: + _debug(' But could not compute first integral.') + continue + res2 = _meijerint_definite_2(f.subs(x, c - x), x) + if res2 is None: + _debug(' But could not compute second integral.') + continue + res1, cond1 = res1 + res2, cond2 = res2 + cond = _condsimp(And(cond1, cond2)) + if cond == False: + _debug(' But combined condition is always false.') + continue + res = res1 + res2 + return res, cond + + elif a is S.Infinity: + res = meijerint_definite(f, x, b, S.Infinity) + return -res[0], res[1] + + elif (a, b) == (S.Zero, S.Infinity): + # This is a common case - try it directly first. + res = _meijerint_definite_2(f, x) + if res: + if _has(res[0], meijerg): + results.append(res) + else: + return res + + else: + if b is S.Infinity: + for split in _find_splitting_points(f, x): + if (a - split >= 0) == True: + _debugf('Trying x -> x + %s', split) + res = _meijerint_definite_2(f.subs(x, x + split) + *Heaviside(x + split - a), x) + if res: + if _has(res[0], meijerg): + results.append(res) + else: + return res + + f = f.subs(x, x + a) + b = b - a + a = 0 + if b is not S.Infinity: + phi = exp(S.ImaginaryUnit*arg(b)) + b = Abs(b) + f = f.subs(x, phi*x) + f *= Heaviside(b - x)*phi + b = S.Infinity + + _debug('Changed limits to', a, b) + _debug('Changed function to', f) + res = _meijerint_definite_2(f, x) + if res: + if _has(res[0], meijerg): + results.append(res) + else: + return res + if f_.has(HyperbolicFunction): + _debug('Try rewriting hyperbolics in terms of exp.') + rv = meijerint_definite( + _rewrite_hyperbolics_as_exp(f_), x_, a_, b_) + if rv: + if not isinstance(rv, list): + from sympy.simplify.radsimp import collect + rv = (collect(factor_terms(rv[0]), rv[0].atoms(exp)),) + rv[1:] + return rv + results.extend(rv) + if results: + return next(ordered(results)) + + +def _guess_expansion(f, x): + """ Try to guess sensible rewritings for integrand f(x). """ + res = [(f, 'original integrand')] + + orig = res[-1][0] + saw = {orig} + expanded = expand_mul(orig) + if expanded not in saw: + res += [(expanded, 'expand_mul')] + saw.add(expanded) + + expanded = expand(orig) + if expanded not in saw: + res += [(expanded, 'expand')] + saw.add(expanded) + + if orig.has(TrigonometricFunction, HyperbolicFunction): + expanded = expand_mul(expand_trig(orig)) + if expanded not in saw: + res += [(expanded, 'expand_trig, expand_mul')] + saw.add(expanded) + + if orig.has(cos, sin): + from sympy.simplify.fu import sincos_to_sum + reduced = sincos_to_sum(orig) + if reduced not in saw: + res += [(reduced, 'trig power reduction')] + saw.add(reduced) + + return res + + +def _meijerint_definite_2(f, x): + """ + Try to integrate f dx from zero to infinity. + + The body of this function computes various 'simplifications' + f1, f2, ... of f (e.g. by calling expand_mul(), trigexpand() + - see _guess_expansion) and calls _meijerint_definite_3 with each of + these in succession. + If _meijerint_definite_3 succeeds with any of the simplified functions, + returns this result. + """ + # This function does preparation for (2), calls + # _meijerint_definite_3 for (2) and (3) combined. + + # use a positive dummy - we integrate from 0 to oo + # XXX if a nonnegative symbol is used there will be test failures + dummy = _dummy('x', 'meijerint-definite2', f, positive=True) + f = f.subs(x, dummy) + x = dummy + + if f == 0: + return S.Zero, True + + for g, explanation in _guess_expansion(f, x): + _debug('Trying', explanation) + res = _meijerint_definite_3(g, x) + if res: + return res + + +def _meijerint_definite_3(f, x): + """ + Try to integrate f dx from zero to infinity. + + This function calls _meijerint_definite_4 to try to compute the + integral. If this fails, it tries using linearity. + """ + res = _meijerint_definite_4(f, x) + if res and res[1] != False: + return res + if f.is_Add: + _debug('Expanding and evaluating all terms.') + ress = [_meijerint_definite_4(g, x) for g in f.args] + if all(r is not None for r in ress): + conds = [] + res = S.Zero + for r, c in ress: + res += r + conds += [c] + c = And(*conds) + if c != False: + return res, c + + +def _my_unpolarify(f): + return _eval_cond(unpolarify(f)) + + +@timeit +def _meijerint_definite_4(f, x, only_double=False): + """ + Try to integrate f dx from zero to infinity. + + Explanation + =========== + + This function tries to apply the integration theorems found in literature, + i.e. it tries to rewrite f as either one or a product of two G-functions. + + The parameter ``only_double`` is used internally in the recursive algorithm + to disable trying to rewrite f as a single G-function. + """ + from sympy.simplify import hyperexpand + # This function does (2) and (3) + _debug('Integrating', f) + # Try single G function. + if not only_double: + gs = _rewrite1(f, x, recursive=False) + if gs is not None: + fac, po, g, cond = gs + _debug('Could rewrite as single G function:', fac, po, g) + res = S.Zero + for C, s, f in g: + if C == 0: + continue + C, f = _rewrite_saxena_1(fac*C, po*x**s, f, x) + res += C*_int0oo_1(f, x) + cond = And(cond, _check_antecedents_1(f, x)) + if cond == False: + break + cond = _my_unpolarify(cond) + if cond == False: + _debug('But cond is always False.') + else: + _debug('Result before branch substitutions is:', res) + return _my_unpolarify(hyperexpand(res)), cond + + # Try two G functions. + gs = _rewrite2(f, x) + if gs is not None: + for full_pb in [False, True]: + fac, po, g1, g2, cond = gs + _debug('Could rewrite as two G functions:', fac, po, g1, g2) + res = S.Zero + for C1, s1, f1 in g1: + for C2, s2, f2 in g2: + r = _rewrite_saxena(fac*C1*C2, po*x**(s1 + s2), + f1, f2, x, full_pb) + if r is None: + _debug('Non-rational exponents.') + return + C, f1_, f2_ = r + _debug('Saxena subst for yielded:', C, f1_, f2_) + cond = And(cond, _check_antecedents(f1_, f2_, x)) + if cond == False: + break + res += C*_int0oo(f1_, f2_, x) + else: + continue + break + cond = _my_unpolarify(cond) + if cond == False: + _debugf('But cond is always False (full_pb=%s).', full_pb) + else: + _debugf('Result before branch substitutions is: %s', (res, )) + if only_double: + return res, cond + return _my_unpolarify(hyperexpand(res)), cond + + +def meijerint_inversion(f, x, t): + r""" + Compute the inverse laplace transform + $\int_{c+i\infty}^{c-i\infty} f(x) e^{tx}\, dx$, + for real c larger than the real part of all singularities of ``f``. + + Note that ``t`` is always assumed real and positive. + + Return None if the integral does not exist or could not be evaluated. + + Examples + ======== + + >>> from sympy.abc import x, t + >>> from sympy.integrals.meijerint import meijerint_inversion + >>> meijerint_inversion(1/x, x, t) + Heaviside(t) + """ + f_ = f + t_ = t + t = Dummy('t', polar=True) # We don't want sqrt(t**2) = abs(t) etc + f = f.subs(t_, t) + _debug('Laplace-inverting', f) + if not _is_analytic(f, x): + _debug('But expression is not analytic.') + return None + # Exponentials correspond to shifts; we filter them out and then + # shift the result later. If we are given an Add this will not + # work, but the calling code will take care of that. + shift = S.Zero + + if f.is_Mul: + args = list(f.args) + elif isinstance(f, exp): + args = [f] + else: + args = None + + if args: + newargs = [] + exponentials = [] + while args: + arg = args.pop() + if isinstance(arg, exp): + arg2 = expand(arg) + if arg2.is_Mul: + args += arg2.args + continue + try: + a, b = _get_coeff_exp(arg.args[0], x) + except _CoeffExpValueError: + b = 0 + if b == 1: + exponentials.append(a) + else: + newargs.append(arg) + elif arg.is_Pow: + arg2 = expand(arg) + if arg2.is_Mul: + args += arg2.args + continue + if x not in arg.base.free_symbols: + try: + a, b = _get_coeff_exp(arg.exp, x) + except _CoeffExpValueError: + b = 0 + if b == 1: + exponentials.append(a*log(arg.base)) + newargs.append(arg) + else: + newargs.append(arg) + shift = Add(*exponentials) + f = Mul(*newargs) + + if x not in f.free_symbols: + _debug('Expression consists of constant and exp shift:', f, shift) + cond = Eq(im(shift), 0) + if cond == False: + _debug('but shift is nonreal, cannot be a Laplace transform') + return None + res = f*DiracDelta(t + shift) + _debug('Result is a delta function, possibly conditional:', res, cond) + # cond is True or Eq + return Piecewise((res.subs(t, t_), cond)) + + gs = _rewrite1(f, x) + if gs is not None: + fac, po, g, cond = gs + _debug('Could rewrite as single G function:', fac, po, g) + res = S.Zero + for C, s, f in g: + C, f = _rewrite_inversion(fac*C, po*x**s, f, x) + res += C*_int_inversion(f, x, t) + cond = And(cond, _check_antecedents_inversion(f, x)) + if cond == False: + break + cond = _my_unpolarify(cond) + if cond == False: + _debug('But cond is always False.') + else: + _debug('Result before branch substitution:', res) + from sympy.simplify import hyperexpand + res = _my_unpolarify(hyperexpand(res)) + if not res.has(Heaviside): + res *= Heaviside(t) + res = res.subs(t, t + shift) + if not isinstance(cond, bool): + cond = cond.subs(t, t + shift) + from .transforms import InverseLaplaceTransform + return Piecewise((res.subs(t, t_), cond), + (InverseLaplaceTransform(f_.subs(t, t_), x, t_, None), True)) diff --git a/.venv/lib/python3.13/site-packages/sympy/integrals/meijerint_doc.py b/.venv/lib/python3.13/site-packages/sympy/integrals/meijerint_doc.py new file mode 100644 index 0000000000000000000000000000000000000000..1df385db6b6fe6d46d4a14217aa53d1fdd9670b5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/integrals/meijerint_doc.py @@ -0,0 +1,38 @@ +""" This module cooks up a docstring when imported. Its only purpose is to + be displayed in the sphinx documentation. """ + +from __future__ import annotations +from typing import Any + +from sympy.integrals.meijerint import _create_lookup_table +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.relational import Eq +from sympy.core.symbol import Symbol +from sympy.printing.latex import latex + +t: dict[tuple[type[Basic], ...], list[Any]] = {} +_create_lookup_table(t) + + +doc = "" +for about, category in t.items(): + if about == (): + doc += 'Elementary functions:\n\n' + else: + doc += 'Functions involving ' + ', '.join('`%s`' % latex( + list(category[0][0].atoms(func))[0]) for func in about) + ':\n\n' + for formula, gs, cond, hint in category: + if not isinstance(gs, list): + g: Expr = Symbol('\\text{generated}') + else: + g = Add(*[fac*f for (fac, f) in gs]) + obj = Eq(formula, g) + if cond is True: + cond = "" + else: + cond = ',\\text{ if } %s' % latex(cond) + doc += ".. math::\n %s%s\n\n" % (latex(obj), cond) + +__doc__ = doc diff --git a/.venv/lib/python3.13/site-packages/sympy/integrals/prde.py b/.venv/lib/python3.13/site-packages/sympy/integrals/prde.py new file mode 100644 index 0000000000000000000000000000000000000000..28e91ea0ff3a82cbeca24c6ed4267503ceb758b5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/integrals/prde.py @@ -0,0 +1,1333 @@ +""" +Algorithms for solving Parametric Risch Differential Equations. + +The methods used for solving Parametric Risch Differential Equations parallel +those for solving Risch Differential Equations. See the outline in the +docstring of rde.py for more information. + +The Parametric Risch Differential Equation problem is, given f, g1, ..., gm in +K(t), to determine if there exist y in K(t) and c1, ..., cm in Const(K) such +that Dy + f*y == Sum(ci*gi, (i, 1, m)), and to find such y and ci if they exist. + +For the algorithms here G is a list of tuples of factions of the terms on the +right hand side of the equation (i.e., gi in k(t)), and Q is a list of terms on +the right hand side of the equation (i.e., qi in k[t]). See the docstring of +each function for more information. +""" +import itertools +from functools import reduce + +from sympy.core.intfunc import ilcm +from sympy.core import Dummy, Add, Mul, Pow, S +from sympy.integrals.rde import (order_at, order_at_oo, weak_normalizer, + bound_degree) +from sympy.integrals.risch import (gcdex_diophantine, frac_in, derivation, + residue_reduce, splitfactor, residue_reduce_derivation, DecrementLevel, + recognize_log_derivative) +from sympy.polys import Poly, lcm, cancel, sqf_list +from sympy.polys.polymatrix import PolyMatrix as Matrix +from sympy.solvers import solve + +zeros = Matrix.zeros +eye = Matrix.eye + + +def prde_normal_denom(fa, fd, G, DE): + """ + Parametric Risch Differential Equation - Normal part of the denominator. + + Explanation + =========== + + Given a derivation D on k[t] and f, g1, ..., gm in k(t) with f weakly + normalized with respect to t, return the tuple (a, b, G, h) such that + a, h in k[t], b in k, G = [g1, ..., gm] in k(t)^m, and for any solution + c1, ..., cm in Const(k) and y in k(t) of Dy + f*y == Sum(ci*gi, (i, 1, m)), + q == y*h in k satisfies a*Dq + b*q == Sum(ci*Gi, (i, 1, m)). + """ + dn, ds = splitfactor(fd, DE) + Gas, Gds = list(zip(*G)) + gd = reduce(lambda i, j: i.lcm(j), Gds, Poly(1, DE.t)) + en, es = splitfactor(gd, DE) + + p = dn.gcd(en) + h = en.gcd(en.diff(DE.t)).quo(p.gcd(p.diff(DE.t))) + + a = dn*h + c = a*h + + ba = a*fa - dn*derivation(h, DE)*fd + ba, bd = ba.cancel(fd, include=True) + + G = [(c*A).cancel(D, include=True) for A, D in G] + + return (a, (ba, bd), G, h) + +def real_imag(ba, bd, gen): + """ + Helper function, to get the real and imaginary part of a rational function + evaluated at sqrt(-1) without actually evaluating it at sqrt(-1). + + Explanation + =========== + + Separates the even and odd power terms by checking the degree of terms wrt + mod 4. Returns a tuple (ba[0], ba[1], bd) where ba[0] is real part + of the numerator ba[1] is the imaginary part and bd is the denominator + of the rational function. + """ + bd = bd.as_poly(gen).as_dict() + ba = ba.as_poly(gen).as_dict() + denom_real = [value if key[0] % 4 == 0 else -value if key[0] % 4 == 2 else 0 for key, value in bd.items()] + denom_imag = [value if key[0] % 4 == 1 else -value if key[0] % 4 == 3 else 0 for key, value in bd.items()] + bd_real = sum(r for r in denom_real) + bd_imag = sum(r for r in denom_imag) + num_real = [value if key[0] % 4 == 0 else -value if key[0] % 4 == 2 else 0 for key, value in ba.items()] + num_imag = [value if key[0] % 4 == 1 else -value if key[0] % 4 == 3 else 0 for key, value in ba.items()] + ba_real = sum(r for r in num_real) + ba_imag = sum(r for r in num_imag) + ba = ((ba_real*bd_real + ba_imag*bd_imag).as_poly(gen), (ba_imag*bd_real - ba_real*bd_imag).as_poly(gen)) + bd = (bd_real*bd_real + bd_imag*bd_imag).as_poly(gen) + return (ba[0], ba[1], bd) + + +def prde_special_denom(a, ba, bd, G, DE, case='auto'): + """ + Parametric Risch Differential Equation - Special part of the denominator. + + Explanation + =========== + + Case is one of {'exp', 'tan', 'primitive'} for the hyperexponential, + hypertangent, and primitive cases, respectively. For the hyperexponential + (resp. hypertangent) case, given a derivation D on k[t] and a in k[t], + b in k, and g1, ..., gm in k(t) with Dt/t in k (resp. Dt/(t**2 + 1) in + k, sqrt(-1) not in k), a != 0, and gcd(a, t) == 1 (resp. + gcd(a, t**2 + 1) == 1), return the tuple (A, B, GG, h) such that A, B, h in + k[t], GG = [gg1, ..., ggm] in k(t)^m, and for any solution c1, ..., cm in + Const(k) and q in k of a*Dq + b*q == Sum(ci*gi, (i, 1, m)), r == q*h in + k[t] satisfies A*Dr + B*r == Sum(ci*ggi, (i, 1, m)). + + For case == 'primitive', k == k[t], so it returns (a, b, G, 1) in this + case. + """ + # TODO: Merge this with the very similar special_denom() in rde.py + if case == 'auto': + case = DE.case + + if case == 'exp': + p = Poly(DE.t, DE.t) + elif case == 'tan': + p = Poly(DE.t**2 + 1, DE.t) + elif case in ('primitive', 'base'): + B = ba.quo(bd) + return (a, B, G, Poly(1, DE.t)) + else: + raise ValueError("case must be one of {'exp', 'tan', 'primitive', " + "'base'}, not %s." % case) + + nb = order_at(ba, p, DE.t) - order_at(bd, p, DE.t) + nc = min(order_at(Ga, p, DE.t) - order_at(Gd, p, DE.t) for Ga, Gd in G) + n = min(0, nc - min(0, nb)) + if not nb: + # Possible cancellation. + if case == 'exp': + dcoeff = DE.d.quo(Poly(DE.t, DE.t)) + with DecrementLevel(DE): # We are guaranteed to not have problems, + # because case != 'base'. + alphaa, alphad = frac_in(-ba.eval(0)/bd.eval(0)/a.eval(0), DE.t) + etaa, etad = frac_in(dcoeff, DE.t) + A = parametric_log_deriv(alphaa, alphad, etaa, etad, DE) + if A is not None: + Q, m, z = A + if Q == 1: + n = min(n, m) + + elif case == 'tan': + dcoeff = DE.d.quo(Poly(DE.t**2 + 1, DE.t)) + with DecrementLevel(DE): # We are guaranteed to not have problems, + # because case != 'base'. + betaa, alphaa, alphad = real_imag(ba, bd*a, DE.t) + betad = alphad + etaa, etad = frac_in(dcoeff, DE.t) + if recognize_log_derivative(Poly(2, DE.t)*betaa, betad, DE): + A = parametric_log_deriv(alphaa, alphad, etaa, etad, DE) + B = parametric_log_deriv(betaa, betad, etaa, etad, DE) + if A is not None and B is not None: + Q, s, z = A + # TODO: Add test + if Q == 1: + n = min(n, s/2) + + N = max(0, -nb) + pN = p**N + pn = p**-n # This is 1/h + + A = a*pN + B = ba*pN.quo(bd) + Poly(n, DE.t)*a*derivation(p, DE).quo(p)*pN + G = [(Ga*pN*pn).cancel(Gd, include=True) for Ga, Gd in G] + h = pn + + # (a*p**N, (b + n*a*Dp/p)*p**N, g1*p**(N - n), ..., gm*p**(N - n), p**-n) + return (A, B, G, h) + + +def prde_linear_constraints(a, b, G, DE): + """ + Parametric Risch Differential Equation - Generate linear constraints on the constants. + + Explanation + =========== + + Given a derivation D on k[t], a, b, in k[t] with gcd(a, b) == 1, and + G = [g1, ..., gm] in k(t)^m, return Q = [q1, ..., qm] in k[t]^m and a + matrix M with entries in k(t) such that for any solution c1, ..., cm in + Const(k) and p in k[t] of a*Dp + b*p == Sum(ci*gi, (i, 1, m)), + (c1, ..., cm) is a solution of Mx == 0, and p and the ci satisfy + a*Dp + b*p == Sum(ci*qi, (i, 1, m)). + + Because M has entries in k(t), and because Matrix does not play well with + Poly, M will be a Matrix of Basic expressions. + """ + m = len(G) + + Gns, Gds = list(zip(*G)) + d = reduce(lambda i, j: i.lcm(j), Gds) + d = Poly(d, field=True) + Q = [(ga*(d).quo(gd)).div(d) for ga, gd in G] + + if not all(ri.is_zero for _, ri in Q): + N = max(ri.degree(DE.t) for _, ri in Q) + M = Matrix(N + 1, m, lambda i, j: Q[j][1].nth(i), DE.t) + else: + M = Matrix(0, m, [], DE.t) # No constraints, return the empty matrix. + + qs, _ = list(zip(*Q)) + return (qs, M) + +def poly_linear_constraints(p, d): + """ + Given p = [p1, ..., pm] in k[t]^m and d in k[t], return + q = [q1, ..., qm] in k[t]^m and a matrix M with entries in k such + that Sum(ci*pi, (i, 1, m)), for c1, ..., cm in k, is divisible + by d if and only if (c1, ..., cm) is a solution of Mx = 0, in + which case the quotient is Sum(ci*qi, (i, 1, m)). + """ + m = len(p) + q, r = zip(*[pi.div(d) for pi in p]) + + if not all(ri.is_zero for ri in r): + n = max(ri.degree() for ri in r) + M = Matrix(n + 1, m, lambda i, j: r[j].nth(i), d.gens) + else: + M = Matrix(0, m, [], d.gens) # No constraints. + + return q, M + +def constant_system(A, u, DE): + """ + Generate a system for the constant solutions. + + Explanation + =========== + + Given a differential field (K, D) with constant field C = Const(K), a Matrix + A, and a vector (Matrix) u with coefficients in K, returns the tuple + (B, v, s), where B is a Matrix with coefficients in C and v is a vector + (Matrix) such that either v has coefficients in C, in which case s is True + and the solutions in C of Ax == u are exactly all the solutions of Bx == v, + or v has a non-constant coefficient, in which case s is False Ax == u has no + constant solution. + + This algorithm is used both in solving parametric problems and in + determining if an element a of K is a derivative of an element of K or the + logarithmic derivative of a K-radical using the structure theorem approach. + + Because Poly does not play well with Matrix yet, this algorithm assumes that + all matrix entries are Basic expressions. + """ + if not A: + return A, u + Au = A.row_join(u) + Au, _ = Au.rref() + # Warning: This will NOT return correct results if cancel() cannot reduce + # an identically zero expression to 0. The danger is that we might + # incorrectly prove that an integral is nonelementary (such as + # risch_integrate(exp((sin(x)**2 + cos(x)**2 - 1)*x**2), x). + # But this is a limitation in computer algebra in general, and implicit + # in the correctness of the Risch Algorithm is the computability of the + # constant field (actually, this same correctness problem exists in any + # algorithm that uses rref()). + # + # We therefore limit ourselves to constant fields that are computable + # via the cancel() function, in order to prevent a speed bottleneck from + # calling some more complex simplification function (rational function + # coefficients will fall into this class). Furthermore, (I believe) this + # problem will only crop up if the integral explicitly contains an + # expression in the constant field that is identically zero, but cannot + # be reduced to such by cancel(). Therefore, a careful user can avoid this + # problem entirely by being careful with the sorts of expressions that + # appear in his integrand in the variables other than the integration + # variable (the structure theorems should be able to completely decide these + # problems in the integration variable). + + A, u = Au[:, :-1], Au[:, -1] + + D = lambda x: derivation(x, DE, basic=True) + + for j, i in itertools.product(range(A.cols), range(A.rows)): + if A[i, j].expr.has(*DE.T): + # This assumes that const(F(t0, ..., tn) == const(K) == F + Ri = A[i, :] + # Rm+1; m = A.rows + DAij = D(A[i, j]) + Rm1 = Ri.applyfunc(lambda x: D(x) / DAij) + um1 = D(u[i]) / DAij + + Aj = A[:, j] + A = A - Aj * Rm1 + u = u - Aj * um1 + + A = A.col_join(Rm1) + u = u.col_join(Matrix([um1], u.gens)) + + return (A, u) + + +def prde_spde(a, b, Q, n, DE): + """ + Special Polynomial Differential Equation algorithm: Parametric Version. + + Explanation + =========== + + Given a derivation D on k[t], an integer n, and a, b, q1, ..., qm in k[t] + with deg(a) > 0 and gcd(a, b) == 1, return (A, B, Q, R, n1), with + Qq = [q1, ..., qm] and R = [r1, ..., rm], such that for any solution + c1, ..., cm in Const(k) and q in k[t] of degree at most n of + a*Dq + b*q == Sum(ci*gi, (i, 1, m)), p = (q - Sum(ci*ri, (i, 1, m)))/a has + degree at most n1 and satisfies A*Dp + B*p == Sum(ci*qi, (i, 1, m)) + """ + R, Z = list(zip(*[gcdex_diophantine(b, a, qi) for qi in Q])) + + A = a + B = b + derivation(a, DE) + Qq = [zi - derivation(ri, DE) for ri, zi in zip(R, Z)] + R = list(R) + n1 = n - a.degree(DE.t) + + return (A, B, Qq, R, n1) + + +def prde_no_cancel_b_large(b, Q, n, DE): + """ + Parametric Poly Risch Differential Equation - No cancellation: deg(b) large enough. + + Explanation + =========== + + Given a derivation D on k[t], n in ZZ, and b, q1, ..., qm in k[t] with + b != 0 and either D == d/dt or deg(b) > max(0, deg(D) - 1), returns + h1, ..., hr in k[t] and a matrix A with coefficients in Const(k) such that + if c1, ..., cm in Const(k) and q in k[t] satisfy deg(q) <= n and + Dq + b*q == Sum(ci*qi, (i, 1, m)), then q = Sum(dj*hj, (j, 1, r)), where + d1, ..., dr in Const(k) and A*Matrix([[c1, ..., cm, d1, ..., dr]]).T == 0. + """ + db = b.degree(DE.t) + m = len(Q) + H = [Poly(0, DE.t)]*m + + for N, i in itertools.product(range(n, -1, -1), range(m)): # [n, ..., 0] + si = Q[i].nth(N + db)/b.LC() + sitn = Poly(si*DE.t**N, DE.t) + H[i] = H[i] + sitn + Q[i] = Q[i] - derivation(sitn, DE) - b*sitn + + if all(qi.is_zero for qi in Q): + dc = -1 + else: + dc = max(qi.degree(DE.t) for qi in Q) + M = Matrix(dc + 1, m, lambda i, j: Q[j].nth(i), DE.t) + A, u = constant_system(M, zeros(dc + 1, 1, DE.t), DE) + c = eye(m, DE.t) + A = A.row_join(zeros(A.rows, m, DE.t)).col_join(c.row_join(-c)) + + return (H, A) + + +def prde_no_cancel_b_small(b, Q, n, DE): + """ + Parametric Poly Risch Differential Equation - No cancellation: deg(b) small enough. + + Explanation + =========== + + Given a derivation D on k[t], n in ZZ, and b, q1, ..., qm in k[t] with + deg(b) < deg(D) - 1 and either D == d/dt or deg(D) >= 2, returns + h1, ..., hr in k[t] and a matrix A with coefficients in Const(k) such that + if c1, ..., cm in Const(k) and q in k[t] satisfy deg(q) <= n and + Dq + b*q == Sum(ci*qi, (i, 1, m)) then q = Sum(dj*hj, (j, 1, r)) where + d1, ..., dr in Const(k) and A*Matrix([[c1, ..., cm, d1, ..., dr]]).T == 0. + """ + m = len(Q) + H = [Poly(0, DE.t)]*m + + for N, i in itertools.product(range(n, 0, -1), range(m)): # [n, ..., 1] + si = Q[i].nth(N + DE.d.degree(DE.t) - 1)/(N*DE.d.LC()) + sitn = Poly(si*DE.t**N, DE.t) + H[i] = H[i] + sitn + Q[i] = Q[i] - derivation(sitn, DE) - b*sitn + + if b.degree(DE.t) > 0: + for i in range(m): + si = Poly(Q[i].nth(b.degree(DE.t))/b.LC(), DE.t) + H[i] = H[i] + si + Q[i] = Q[i] - derivation(si, DE) - b*si + if all(qi.is_zero for qi in Q): + dc = -1 + else: + dc = max(qi.degree(DE.t) for qi in Q) + M = Matrix(dc + 1, m, lambda i, j: Q[j].nth(i), DE.t) + A, u = constant_system(M, zeros(dc + 1, 1, DE.t), DE) + c = eye(m, DE.t) + A = A.row_join(zeros(A.rows, m, DE.t)).col_join(c.row_join(-c)) + return (H, A) + + # else: b is in k, deg(qi) < deg(Dt) + + t = DE.t + if DE.case != 'base': + with DecrementLevel(DE): + t0 = DE.t # k = k0(t0) + ba, bd = frac_in(b, t0, field=True) + Q0 = [frac_in(qi.TC(), t0, field=True) for qi in Q] + f, B = param_rischDE(ba, bd, Q0, DE) + + # f = [f1, ..., fr] in k^r and B is a matrix with + # m + r columns and entries in Const(k) = Const(k0) + # such that Dy0 + b*y0 = Sum(ci*qi, (i, 1, m)) has + # a solution y0 in k with c1, ..., cm in Const(k) + # if and only y0 = Sum(dj*fj, (j, 1, r)) where + # d1, ..., dr ar in Const(k) and + # B*Matrix([c1, ..., cm, d1, ..., dr]) == 0. + + # Transform fractions (fa, fd) in f into constant + # polynomials fa/fd in k[t]. + # (Is there a better way?) + f = [Poly(fa.as_expr()/fd.as_expr(), t, field=True) + for fa, fd in f] + B = Matrix.from_Matrix(B.to_Matrix(), t) + else: + # Base case. Dy == 0 for all y in k and b == 0. + # Dy + b*y = Sum(ci*qi) is solvable if and only if + # Sum(ci*qi) == 0 in which case the solutions are + # y = d1*f1 for f1 = 1 and any d1 in Const(k) = k. + + f = [Poly(1, t, field=True)] # r = 1 + B = Matrix([[qi.TC() for qi in Q] + [S.Zero]], DE.t) + # The condition for solvability is + # B*Matrix([c1, ..., cm, d1]) == 0 + # There are no constraints on d1. + + # Coefficients of t^j (j > 0) in Sum(ci*qi) must be zero. + d = max(qi.degree(DE.t) for qi in Q) + if d > 0: + M = Matrix(d, m, lambda i, j: Q[j].nth(i + 1), DE.t) + A, _ = constant_system(M, zeros(d, 1, DE.t), DE) + else: + # No constraints on the hj. + A = Matrix(0, m, [], DE.t) + + # Solutions of the original equation are + # y = Sum(dj*fj, (j, 1, r) + Sum(ei*hi, (i, 1, m)), + # where ei == ci (i = 1, ..., m), when + # A*Matrix([c1, ..., cm]) == 0 and + # B*Matrix([c1, ..., cm, d1, ..., dr]) == 0 + + # Build combined constraint matrix with m + r + m columns. + + r = len(f) + I = eye(m, DE.t) + A = A.row_join(zeros(A.rows, r + m, DE.t)) + B = B.row_join(zeros(B.rows, m, DE.t)) + C = I.row_join(zeros(m, r, DE.t)).row_join(-I) + + return f + H, A.col_join(B).col_join(C) + + +def prde_cancel_liouvillian(b, Q, n, DE): + """ + Pg, 237. + """ + H = [] + + # Why use DecrementLevel? Below line answers that: + # Assuming that we can solve such problems over 'k' (not k[t]) + if DE.case == 'primitive': + with DecrementLevel(DE): + ba, bd = frac_in(b, DE.t, field=True) + + for i in range(n, -1, -1): + if DE.case == 'exp': # this re-checking can be avoided + with DecrementLevel(DE): + ba, bd = frac_in(b + (i*(derivation(DE.t, DE)/DE.t)).as_poly(b.gens), + DE.t, field=True) + with DecrementLevel(DE): + Qy = [frac_in(q.nth(i), DE.t, field=True) for q in Q] + fi, Ai = param_rischDE(ba, bd, Qy, DE) + fi = [Poly(fa.as_expr()/fd.as_expr(), DE.t, field=True) + for fa, fd in fi] + Ai = Ai.set_gens(DE.t) + + ri = len(fi) + + if i == n: + M = Ai + else: + M = Ai.col_join(M.row_join(zeros(M.rows, ri, DE.t))) + + Fi, hi = [None]*ri, [None]*ri + + # from eq. on top of p.238 (unnumbered) + for j in range(ri): + hji = fi[j] * (DE.t**i).as_poly(fi[j].gens) + hi[j] = hji + # building up Sum(djn*(D(fjn*t^n) - b*fjnt^n)) + Fi[j] = -(derivation(hji, DE) - b*hji) + + H += hi + # in the next loop instead of Q it has + # to be Q + Fi taking its place + Q = Q + Fi + + return (H, M) + + +def param_poly_rischDE(a, b, q, n, DE): + """Polynomial solutions of a parametric Risch differential equation. + + Explanation + =========== + + Given a derivation D in k[t], a, b in k[t] relatively prime, and q + = [q1, ..., qm] in k[t]^m, return h = [h1, ..., hr] in k[t]^r and + a matrix A with m + r columns and entries in Const(k) such that + a*Dp + b*p = Sum(ci*qi, (i, 1, m)) has a solution p of degree <= n + in k[t] with c1, ..., cm in Const(k) if and only if p = Sum(dj*hj, + (j, 1, r)) where d1, ..., dr are in Const(k) and (c1, ..., cm, + d1, ..., dr) is a solution of Ax == 0. + """ + m = len(q) + if n < 0: + # Only the trivial zero solution is possible. + # Find relations between the qi. + if all(qi.is_zero for qi in q): + return [], zeros(1, m, DE.t) # No constraints. + + N = max(qi.degree(DE.t) for qi in q) + M = Matrix(N + 1, m, lambda i, j: q[j].nth(i), DE.t) + A, _ = constant_system(M, zeros(M.rows, 1, DE.t), DE) + + return [], A + + if a.is_ground: + # Normalization: a = 1. + a = a.LC() + b, q = b.to_field().exquo_ground(a), [qi.to_field().exquo_ground(a) for qi in q] + + if not b.is_zero and (DE.case == 'base' or + b.degree() > max(0, DE.d.degree() - 1)): + return prde_no_cancel_b_large(b, q, n, DE) + + elif ((b.is_zero or b.degree() < DE.d.degree() - 1) + and (DE.case == 'base' or DE.d.degree() >= 2)): + return prde_no_cancel_b_small(b, q, n, DE) + + elif (DE.d.degree() >= 2 and + b.degree() == DE.d.degree() - 1 and + n > -b.as_poly().LC()/DE.d.as_poly().LC()): + raise NotImplementedError("prde_no_cancel_b_equal() is " + "not yet implemented.") + + else: + # Liouvillian cases + if DE.case in ('primitive', 'exp'): + return prde_cancel_liouvillian(b, q, n, DE) + else: + raise NotImplementedError("non-linear and hypertangent " + "cases have not yet been implemented") + + # else: deg(a) > 0 + + # Iterate SPDE as long as possible cumulating coefficient + # and terms for the recovery of original solutions. + alpha, beta = a.one, [a.zero]*m + while n >= 0: # and a, b relatively prime + a, b, q, r, n = prde_spde(a, b, q, n, DE) + beta = [betai + alpha*ri for betai, ri in zip(beta, r)] + alpha *= a + # Solutions p of a*Dp + b*p = Sum(ci*qi) correspond to + # solutions alpha*p + Sum(ci*betai) of the initial equation. + d = a.gcd(b) + if not d.is_ground: + break + + # a*Dp + b*p = Sum(ci*qi) may have a polynomial solution + # only if the sum is divisible by d. + + qq, M = poly_linear_constraints(q, d) + # qq = [qq1, ..., qqm] where qqi = qi.quo(d). + # M is a matrix with m columns an entries in k. + # Sum(fi*qi, (i, 1, m)), where f1, ..., fm are elements of k, is + # divisible by d if and only if M*Matrix([f1, ..., fm]) == 0, + # in which case the quotient is Sum(fi*qqi). + + A, _ = constant_system(M, zeros(M.rows, 1, DE.t), DE) + # A is a matrix with m columns and entries in Const(k). + # Sum(ci*qqi) is Sum(ci*qi).quo(d), and the remainder is zero + # for c1, ..., cm in Const(k) if and only if + # A*Matrix([c1, ...,cm]) == 0. + + V = A.nullspace() + # V = [v1, ..., vu] where each vj is a column matrix with + # entries aj1, ..., ajm in Const(k). + # Sum(aji*qi) is divisible by d with exact quotient Sum(aji*qqi). + # Sum(ci*qi) is divisible by d if and only if ci = Sum(dj*aji) + # (i = 1, ..., m) for some d1, ..., du in Const(k). + # In that case, solutions of + # a*Dp + b*p = Sum(ci*qi) = Sum(dj*Sum(aji*qi)) + # are the same as those of + # (a/d)*Dp + (b/d)*p = Sum(dj*rj) + # where rj = Sum(aji*qqi). + + if not V: # No non-trivial solution. + return [], eye(m, DE.t) # Could return A, but this has + # the minimum number of rows. + + Mqq = Matrix([qq]) # A single row. + r = [(Mqq*vj)[0] for vj in V] # [r1, ..., ru] + + # Solutions of (a/d)*Dp + (b/d)*p = Sum(dj*rj) correspond to + # solutions alpha*p + Sum(Sum(dj*aji)*betai) of the initial + # equation. These are equal to alpha*p + Sum(dj*fj) where + # fj = Sum(aji*betai). + Mbeta = Matrix([beta]) + f = [(Mbeta*vj)[0] for vj in V] # [f1, ..., fu] + + # + # Solve the reduced equation recursively. + # + g, B = param_poly_rischDE(a.quo(d), b.quo(d), r, n, DE) + + # g = [g1, ..., gv] in k[t]^v and and B is a matrix with u + v + # columns and entries in Const(k) such that + # (a/d)*Dp + (b/d)*p = Sum(dj*rj) has a solution p of degree <= n + # in k[t] if and only if p = Sum(ek*gk) where e1, ..., ev are in + # Const(k) and B*Matrix([d1, ..., du, e1, ..., ev]) == 0. + # The solutions of the original equation are then + # Sum(dj*fj, (j, 1, u)) + alpha*Sum(ek*gk, (k, 1, v)). + + # Collect solution components. + h = f + [alpha*gk for gk in g] + + # Build combined relation matrix. + A = -eye(m, DE.t) + for vj in V: + A = A.row_join(vj) + A = A.row_join(zeros(m, len(g), DE.t)) + A = A.col_join(zeros(B.rows, m, DE.t).row_join(B)) + + return h, A + + +def param_rischDE(fa, fd, G, DE): + """ + Solve a Parametric Risch Differential Equation: Dy + f*y == Sum(ci*Gi, (i, 1, m)). + + Explanation + =========== + + Given a derivation D in k(t), f in k(t), and G + = [G1, ..., Gm] in k(t)^m, return h = [h1, ..., hr] in k(t)^r and + a matrix A with m + r columns and entries in Const(k) such that + Dy + f*y = Sum(ci*Gi, (i, 1, m)) has a solution y + in k(t) with c1, ..., cm in Const(k) if and only if y = Sum(dj*hj, + (j, 1, r)) where d1, ..., dr are in Const(k) and (c1, ..., cm, + d1, ..., dr) is a solution of Ax == 0. + + Elements of k(t) are tuples (a, d) with a and d in k[t]. + """ + m = len(G) + q, (fa, fd) = weak_normalizer(fa, fd, DE) + # Solutions of the weakly normalized equation Dz + f*z = q*Sum(ci*Gi) + # correspond to solutions y = z/q of the original equation. + gamma = q + G = [(q*ga).cancel(gd, include=True) for ga, gd in G] + + a, (ba, bd), G, hn = prde_normal_denom(fa, fd, G, DE) + # Solutions q in k of a*Dq + b*q = Sum(ci*Gi) correspond + # to solutions z = q/hn of the weakly normalized equation. + gamma *= hn + + A, B, G, hs = prde_special_denom(a, ba, bd, G, DE) + # Solutions p in k[t] of A*Dp + B*p = Sum(ci*Gi) correspond + # to solutions q = p/hs of the previous equation. + gamma *= hs + + g = A.gcd(B) + a, b, g = A.quo(g), B.quo(g), [gia.cancel(gid*g, include=True) for + gia, gid in G] + + # a*Dp + b*p = Sum(ci*gi) may have a polynomial solution + # only if the sum is in k[t]. + + q, M = prde_linear_constraints(a, b, g, DE) + + # q = [q1, ..., qm] where qi in k[t] is the polynomial component + # of the partial fraction expansion of gi. + # M is a matrix with m columns and entries in k. + # Sum(fi*gi, (i, 1, m)), where f1, ..., fm are elements of k, + # is a polynomial if and only if M*Matrix([f1, ..., fm]) == 0, + # in which case the sum is equal to Sum(fi*qi). + + M, _ = constant_system(M, zeros(M.rows, 1, DE.t), DE) + # M is a matrix with m columns and entries in Const(k). + # Sum(ci*gi) is in k[t] for c1, ..., cm in Const(k) + # if and only if M*Matrix([c1, ..., cm]) == 0, + # in which case the sum is Sum(ci*qi). + + ## Reduce number of constants at this point + + V = M.nullspace() + # V = [v1, ..., vu] where each vj is a column matrix with + # entries aj1, ..., ajm in Const(k). + # Sum(aji*gi) is in k[t] and equal to Sum(aji*qi) (j = 1, ..., u). + # Sum(ci*gi) is in k[t] if and only is ci = Sum(dj*aji) + # (i = 1, ..., m) for some d1, ..., du in Const(k). + # In that case, + # Sum(ci*gi) = Sum(ci*qi) = Sum(dj*Sum(aji*qi)) = Sum(dj*rj) + # where rj = Sum(aji*qi) (j = 1, ..., u) in k[t]. + + if not V: # No non-trivial solution + return [], eye(m, DE.t) + + Mq = Matrix([q]) # A single row. + r = [(Mq*vj)[0] for vj in V] # [r1, ..., ru] + + # Solutions of a*Dp + b*p = Sum(dj*rj) correspond to solutions + # y = p/gamma of the initial equation with ci = Sum(dj*aji). + + try: + # We try n=5. At least for prde_spde, it will always + # terminate no matter what n is. + n = bound_degree(a, b, r, DE, parametric=True) + except NotImplementedError: + # A temporary bound is set. Eventually, it will be removed. + # the currently added test case takes large time + # even with n=5, and much longer with large n's. + n = 5 + + h, B = param_poly_rischDE(a, b, r, n, DE) + + # h = [h1, ..., hv] in k[t]^v and and B is a matrix with u + v + # columns and entries in Const(k) such that + # a*Dp + b*p = Sum(dj*rj) has a solution p of degree <= n + # in k[t] if and only if p = Sum(ek*hk) where e1, ..., ev are in + # Const(k) and B*Matrix([d1, ..., du, e1, ..., ev]) == 0. + # The solutions of the original equation for ci = Sum(dj*aji) + # (i = 1, ..., m) are then y = Sum(ek*hk, (k, 1, v))/gamma. + + ## Build combined relation matrix with m + u + v columns. + + A = -eye(m, DE.t) + for vj in V: + A = A.row_join(vj) + A = A.row_join(zeros(m, len(h), DE.t)) + A = A.col_join(zeros(B.rows, m, DE.t).row_join(B)) + + ## Eliminate d1, ..., du. + + W = A.nullspace() + + # W = [w1, ..., wt] where each wl is a column matrix with + # entries blk (k = 1, ..., m + u + v) in Const(k). + # The vectors (bl1, ..., blm) generate the space of those + # constant families (c1, ..., cm) for which a solution of + # the equation Dy + f*y == Sum(ci*Gi) exists. They generate + # the space and form a basis except possibly when Dy + f*y == 0 + # is solvable in k(t}. The corresponding solutions are + # y = Sum(blk'*hk, (k, 1, v))/gamma, where k' = k + m + u. + + v = len(h) + shape = (len(W), m+v) + elements = [wl[:m] + wl[-v:] for wl in W] # excise dj's. + items = [e for row in elements for e in row] + + # Need to set the shape in case W is empty + M = Matrix(*shape, items, DE.t) + N = M.nullspace() + + # N = [n1, ..., ns] where the ni in Const(k)^(m + v) are column + # vectors generating the space of linear relations between + # c1, ..., cm, e1, ..., ev. + + C = Matrix([ni[:] for ni in N], DE.t) # rows n1, ..., ns. + + return [hk.cancel(gamma, include=True) for hk in h], C + + +def limited_integrate_reduce(fa, fd, G, DE): + """ + Simpler version of step 1 & 2 for the limited integration problem. + + Explanation + =========== + + Given a derivation D on k(t) and f, g1, ..., gn in k(t), return + (a, b, h, N, g, V) such that a, b, h in k[t], N is a non-negative integer, + g in k(t), V == [v1, ..., vm] in k(t)^m, and for any solution v in k(t), + c1, ..., cm in C of f == Dv + Sum(ci*wi, (i, 1, m)), p = v*h is in k, and + p and the ci satisfy a*Dp + b*p == g + Sum(ci*vi, (i, 1, m)). Furthermore, + if S1irr == Sirr, then p is in k[t], and if t is nonlinear or Liouvillian + over k, then deg(p) <= N. + + So that the special part is always computed, this function calls the more + general prde_special_denom() automatically if it cannot determine that + S1irr == Sirr. Furthermore, it will automatically call bound_degree() when + t is linear and non-Liouvillian, which for the transcendental case, implies + that Dt == a*t + b with for some a, b in k*. + """ + dn, ds = splitfactor(fd, DE) + E = [splitfactor(gd, DE) for _, gd in G] + En, Es = list(zip(*E)) + c = reduce(lambda i, j: i.lcm(j), (dn,) + En) # lcm(dn, en1, ..., enm) + hn = c.gcd(c.diff(DE.t)) + a = hn + b = -derivation(hn, DE) + N = 0 + + # These are the cases where we know that S1irr = Sirr, but there could be + # others, and this algorithm will need to be extended to handle them. + if DE.case in ('base', 'primitive', 'exp', 'tan'): + hs = reduce(lambda i, j: i.lcm(j), (ds,) + Es) # lcm(ds, es1, ..., esm) + a = hn*hs + b -= (hn*derivation(hs, DE)).quo(hs) + mu = min(order_at_oo(fa, fd, DE.t), min(order_at_oo(ga, gd, DE.t) for + ga, gd in G)) + # So far, all the above are also nonlinear or Liouvillian, but if this + # changes, then this will need to be updated to call bound_degree() + # as per the docstring of this function (DE.case == 'other_linear'). + N = hn.degree(DE.t) + hs.degree(DE.t) + max(0, 1 - DE.d.degree(DE.t) - mu) + else: + # TODO: implement this + raise NotImplementedError + + V = [(-a*hn*ga).cancel(gd, include=True) for ga, gd in G] + return (a, b, a, N, (a*hn*fa).cancel(fd, include=True), V) + + +def limited_integrate(fa, fd, G, DE): + """ + Solves the limited integration problem: f = Dv + Sum(ci*wi, (i, 1, n)) + """ + fa, fd = fa*Poly(1/fd.LC(), DE.t), fd.monic() + # interpreting limited integration problem as a + # parametric Risch DE problem + Fa = Poly(0, DE.t) + Fd = Poly(1, DE.t) + G = [(fa, fd)] + G + h, A = param_rischDE(Fa, Fd, G, DE) + V = A.nullspace() + V = [v for v in V if v[0] != 0] + if not V: + return None + else: + # we can take any vector from V, we take V[0] + c0 = V[0][0] + # v = [-1, c1, ..., cm, d1, ..., dr] + v = V[0]/(-c0) + r = len(h) + m = len(v) - r - 1 + C = list(v[1: m + 1]) + y = -sum(v[m + 1 + i]*h[i][0].as_expr()/h[i][1].as_expr() \ + for i in range(r)) + y_num, y_den = y.as_numer_denom() + Ya, Yd = Poly(y_num, DE.t), Poly(y_den, DE.t) + Y = Ya*Poly(1/Yd.LC(), DE.t), Yd.monic() + return Y, C + + +def parametric_log_deriv_heu(fa, fd, wa, wd, DE, c1=None): + """ + Parametric logarithmic derivative heuristic. + + Explanation + =========== + + Given a derivation D on k[t], f in k(t), and a hyperexponential monomial + theta over k(t), raises either NotImplementedError, in which case the + heuristic failed, or returns None, in which case it has proven that no + solution exists, or returns a solution (n, m, v) of the equation + n*f == Dv/v + m*Dtheta/theta, with v in k(t)* and n, m in ZZ with n != 0. + + If this heuristic fails, the structure theorem approach will need to be + used. + + The argument w == Dtheta/theta + """ + # TODO: finish writing this and write tests + c1 = c1 or Dummy('c1') + + p, a = fa.div(fd) + q, b = wa.div(wd) + + B = max(0, derivation(DE.t, DE).degree(DE.t) - 1) + C = max(p.degree(DE.t), q.degree(DE.t)) + + if q.degree(DE.t) > B: + eqs = [p.nth(i) - c1*q.nth(i) for i in range(B + 1, C + 1)] + s = solve(eqs, c1) + if not s or not s[c1].is_Rational: + # deg(q) > B, no solution for c. + return None + + M, N = s[c1].as_numer_denom() + M_poly = M.as_poly(q.gens) + N_poly = N.as_poly(q.gens) + + nfmwa = N_poly*fa*wd - M_poly*wa*fd + nfmwd = fd*wd + Qv = is_log_deriv_k_t_radical_in_field(nfmwa, nfmwd, DE, 'auto') + if Qv is None: + # (N*f - M*w) is not the logarithmic derivative of a k(t)-radical. + return None + + Q, v = Qv + + if Q.is_zero or v.is_zero: + return None + + return (Q*N, Q*M, v) + + if p.degree(DE.t) > B: + return None + + c = lcm(fd.as_poly(DE.t).LC(), wd.as_poly(DE.t).LC()) + l = fd.monic().lcm(wd.monic())*Poly(c, DE.t) + ln, ls = splitfactor(l, DE) + z = ls*ln.gcd(ln.diff(DE.t)) + + if not z.has(DE.t): + # TODO: We treat this as 'no solution', until the structure + # theorem version of parametric_log_deriv is implemented. + return None + + u1, r1 = (fa*l.quo(fd)).div(z) # (l*f).div(z) + u2, r2 = (wa*l.quo(wd)).div(z) # (l*w).div(z) + + eqs = [r1.nth(i) - c1*r2.nth(i) for i in range(z.degree(DE.t))] + s = solve(eqs, c1) + if not s or not s[c1].is_Rational: + # deg(q) <= B, no solution for c. + return None + + M, N = s[c1].as_numer_denom() + + nfmwa = N.as_poly(DE.t)*fa*wd - M.as_poly(DE.t)*wa*fd + nfmwd = fd*wd + Qv = is_log_deriv_k_t_radical_in_field(nfmwa, nfmwd, DE) + if Qv is None: + # (N*f - M*w) is not the logarithmic derivative of a k(t)-radical. + return None + + Q, v = Qv + + if Q.is_zero or v.is_zero: + return None + + return (Q*N, Q*M, v) + + +def parametric_log_deriv(fa, fd, wa, wd, DE): + # TODO: Write the full algorithm using the structure theorems. +# try: + A = parametric_log_deriv_heu(fa, fd, wa, wd, DE) +# except NotImplementedError: + # Heuristic failed, we have to use the full method. + # TODO: This could be implemented more efficiently. + # It isn't too worrisome, because the heuristic handles most difficult + # cases. + return A + + +def is_deriv_k(fa, fd, DE): + r""" + Checks if Df/f is the derivative of an element of k(t). + + Explanation + =========== + + a in k(t) is the derivative of an element of k(t) if there exists b in k(t) + such that a = Db. Either returns (ans, u), such that Df/f == Du, or None, + which means that Df/f is not the derivative of an element of k(t). ans is + a list of tuples such that Add(*[i*j for i, j in ans]) == u. This is useful + for seeing exactly which elements of k(t) produce u. + + This function uses the structure theorem approach, which says that for any + f in K, Df/f is the derivative of a element of K if and only if there are ri + in QQ such that:: + + --- --- Dt + \ r * Dt + \ r * i Df + / i i / i --- = --. + --- --- t f + i in L i in E i + K/C(x) K/C(x) + + + Where C = Const(K), L_K/C(x) = { i in {1, ..., n} such that t_i is + transcendental over C(x)(t_1, ..., t_i-1) and Dt_i = Da_i/a_i, for some a_i + in C(x)(t_1, ..., t_i-1)* } (i.e., the set of all indices of logarithmic + monomials of K over C(x)), and E_K/C(x) = { i in {1, ..., n} such that t_i + is transcendental over C(x)(t_1, ..., t_i-1) and Dt_i/t_i = Da_i, for some + a_i in C(x)(t_1, ..., t_i-1) } (i.e., the set of all indices of + hyperexponential monomials of K over C(x)). If K is an elementary extension + over C(x), then the cardinality of L_K/C(x) U E_K/C(x) is exactly the + transcendence degree of K over C(x). Furthermore, because Const_D(K) == + Const_D(C(x)) == C, deg(Dt_i) == 1 when t_i is in E_K/C(x) and + deg(Dt_i) == 0 when t_i is in L_K/C(x), implying in particular that E_K/C(x) + and L_K/C(x) are disjoint. + + The sets L_K/C(x) and E_K/C(x) must, by their nature, be computed + recursively using this same function. Therefore, it is required to pass + them as indices to D (or T). E_args are the arguments of the + hyperexponentials indexed by E_K (i.e., if i is in E_K, then T[i] == + exp(E_args[i])). This is needed to compute the final answer u such that + Df/f == Du. + + log(f) will be the same as u up to a additive constant. This is because + they will both behave the same as monomials. For example, both log(x) and + log(2*x) == log(x) + log(2) satisfy Dt == 1/x, because log(2) is constant. + Therefore, the term const is returned. const is such that + log(const) + f == u. This is calculated by dividing the arguments of one + logarithm from the other. Therefore, it is necessary to pass the arguments + of the logarithmic terms in L_args. + + To handle the case where we are given Df/f, not f, use is_deriv_k_in_field(). + + See also + ======== + is_log_deriv_k_t_radical_in_field, is_log_deriv_k_t_radical + + """ + # Compute Df/f + dfa, dfd = (fd*derivation(fa, DE) - fa*derivation(fd, DE)), fd*fa + dfa, dfd = dfa.cancel(dfd, include=True) + + # Our assumption here is that each monomial is recursively transcendental + if len(DE.exts) != len(DE.D): + if [i for i in DE.cases if i == 'tan'] or \ + ({i for i in DE.cases if i == 'primitive'} - + set(DE.indices('log'))): + raise NotImplementedError("Real version of the structure " + "theorems with hypertangent support is not yet implemented.") + + # TODO: What should really be done in this case? + raise NotImplementedError("Nonelementary extensions not supported " + "in the structure theorems.") + + E_part = [DE.D[i].quo(Poly(DE.T[i], DE.T[i])).as_expr() for i in DE.indices('exp')] + L_part = [DE.D[i].as_expr() for i in DE.indices('log')] + + # The expression dfa/dfd might not be polynomial in any of its symbols so we + # use a Dummy as the generator for PolyMatrix. + dum = Dummy() + lhs = Matrix([E_part + L_part], dum) + rhs = Matrix([dfa.as_expr()/dfd.as_expr()], dum) + + A, u = constant_system(lhs, rhs, DE) + + u = u.to_Matrix() # Poly to Expr + + if not A or not all(derivation(i, DE, basic=True).is_zero for i in u): + # If the elements of u are not all constant + # Note: See comment in constant_system + + # Also note: derivation(basic=True) calls cancel() + return None + else: + if not all(i.is_Rational for i in u): + raise NotImplementedError("Cannot work with non-rational " + "coefficients in this case.") + else: + terms = ([DE.extargs[i] for i in DE.indices('exp')] + + [DE.T[i] for i in DE.indices('log')]) + ans = list(zip(terms, u)) + result = Add(*[Mul(i, j) for i, j in ans]) + argterms = ([DE.T[i] for i in DE.indices('exp')] + + [DE.extargs[i] for i in DE.indices('log')]) + l = [] + ld = [] + for i, j in zip(argterms, u): + # We need to get around things like sqrt(x**2) != x + # and also sqrt(x**2 + 2*x + 1) != x + 1 + # Issue 10798: i need not be a polynomial + i, d = i.as_numer_denom() + icoeff, iterms = sqf_list(i) + l.append(Mul(*([Pow(icoeff, j)] + [Pow(b, e*j) for b, e in iterms]))) + dcoeff, dterms = sqf_list(d) + ld.append(Mul(*([Pow(dcoeff, j)] + [Pow(b, e*j) for b, e in dterms]))) + const = cancel(fa.as_expr()/fd.as_expr()/Mul(*l)*Mul(*ld)) + + return (ans, result, const) + + +def is_log_deriv_k_t_radical(fa, fd, DE, Df=True): + r""" + Checks if Df is the logarithmic derivative of a k(t)-radical. + + Explanation + =========== + + b in k(t) can be written as the logarithmic derivative of a k(t) radical if + there exist n in ZZ and u in k(t) with n, u != 0 such that n*b == Du/u. + Either returns (ans, u, n, const) or None, which means that Df cannot be + written as the logarithmic derivative of a k(t)-radical. ans is a list of + tuples such that Mul(*[i**j for i, j in ans]) == u. This is useful for + seeing exactly what elements of k(t) produce u. + + This function uses the structure theorem approach, which says that for any + f in K, Df is the logarithmic derivative of a K-radical if and only if there + are ri in QQ such that:: + + --- --- Dt + \ r * Dt + \ r * i + / i i / i --- = Df. + --- --- t + i in L i in E i + K/C(x) K/C(x) + + + Where C = Const(K), L_K/C(x) = { i in {1, ..., n} such that t_i is + transcendental over C(x)(t_1, ..., t_i-1) and Dt_i = Da_i/a_i, for some a_i + in C(x)(t_1, ..., t_i-1)* } (i.e., the set of all indices of logarithmic + monomials of K over C(x)), and E_K/C(x) = { i in {1, ..., n} such that t_i + is transcendental over C(x)(t_1, ..., t_i-1) and Dt_i/t_i = Da_i, for some + a_i in C(x)(t_1, ..., t_i-1) } (i.e., the set of all indices of + hyperexponential monomials of K over C(x)). If K is an elementary extension + over C(x), then the cardinality of L_K/C(x) U E_K/C(x) is exactly the + transcendence degree of K over C(x). Furthermore, because Const_D(K) == + Const_D(C(x)) == C, deg(Dt_i) == 1 when t_i is in E_K/C(x) and + deg(Dt_i) == 0 when t_i is in L_K/C(x), implying in particular that E_K/C(x) + and L_K/C(x) are disjoint. + + The sets L_K/C(x) and E_K/C(x) must, by their nature, be computed + recursively using this same function. Therefore, it is required to pass + them as indices to D (or T). L_args are the arguments of the logarithms + indexed by L_K (i.e., if i is in L_K, then T[i] == log(L_args[i])). This is + needed to compute the final answer u such that n*f == Du/u. + + exp(f) will be the same as u up to a multiplicative constant. This is + because they will both behave the same as monomials. For example, both + exp(x) and exp(x + 1) == E*exp(x) satisfy Dt == t. Therefore, the term const + is returned. const is such that exp(const)*f == u. This is calculated by + subtracting the arguments of one exponential from the other. Therefore, it + is necessary to pass the arguments of the exponential terms in E_args. + + To handle the case where we are given Df, not f, use + is_log_deriv_k_t_radical_in_field(). + + See also + ======== + + is_log_deriv_k_t_radical_in_field, is_deriv_k + + """ + if Df: + dfa, dfd = (fd*derivation(fa, DE) - fa*derivation(fd, DE)).cancel(fd**2, + include=True) + else: + dfa, dfd = fa, fd + + # Our assumption here is that each monomial is recursively transcendental + if len(DE.exts) != len(DE.D): + if [i for i in DE.cases if i == 'tan'] or \ + ({i for i in DE.cases if i == 'primitive'} - + set(DE.indices('log'))): + raise NotImplementedError("Real version of the structure " + "theorems with hypertangent support is not yet implemented.") + + # TODO: What should really be done in this case? + raise NotImplementedError("Nonelementary extensions not supported " + "in the structure theorems.") + + E_part = [DE.D[i].quo(Poly(DE.T[i], DE.T[i])).as_expr() for i in DE.indices('exp')] + L_part = [DE.D[i].as_expr() for i in DE.indices('log')] + + # The expression dfa/dfd might not be polynomial in any of its symbols so we + # use a Dummy as the generator for PolyMatrix. + dum = Dummy() + lhs = Matrix([E_part + L_part], dum) + rhs = Matrix([dfa.as_expr()/dfd.as_expr()], dum) + + A, u = constant_system(lhs, rhs, DE) + + u = u.to_Matrix() # Poly to Expr + + if not A or not all(derivation(i, DE, basic=True).is_zero for i in u): + # If the elements of u are not all constant + # Note: See comment in constant_system + + # Also note: derivation(basic=True) calls cancel() + return None + else: + if not all(i.is_Rational for i in u): + # TODO: But maybe we can tell if they're not rational, like + # log(2)/log(3). Also, there should be an option to continue + # anyway, even if the result might potentially be wrong. + raise NotImplementedError("Cannot work with non-rational " + "coefficients in this case.") + else: + n = S.One*reduce(ilcm, [i.as_numer_denom()[1] for i in u]) + u *= n + terms = ([DE.T[i] for i in DE.indices('exp')] + + [DE.extargs[i] for i in DE.indices('log')]) + ans = list(zip(terms, u)) + result = Mul(*[Pow(i, j) for i, j in ans]) + + # exp(f) will be the same as result up to a multiplicative + # constant. We now find the log of that constant. + argterms = ([DE.extargs[i] for i in DE.indices('exp')] + + [DE.T[i] for i in DE.indices('log')]) + const = cancel(fa.as_expr()/fd.as_expr() - + Add(*[Mul(i, j/n) for i, j in zip(argterms, u)])) + + return (ans, result, n, const) + + +def is_log_deriv_k_t_radical_in_field(fa, fd, DE, case='auto', z=None): + """ + Checks if f can be written as the logarithmic derivative of a k(t)-radical. + + Explanation + =========== + + It differs from is_log_deriv_k_t_radical(fa, fd, DE, Df=False) + for any given fa, fd, DE in that it finds the solution in the + given field not in some (possibly unspecified extension) and + "in_field" with the function name is used to indicate that. + + f in k(t) can be written as the logarithmic derivative of a k(t) radical if + there exist n in ZZ and u in k(t) with n, u != 0 such that n*f == Du/u. + Either returns (n, u) or None, which means that f cannot be written as the + logarithmic derivative of a k(t)-radical. + + case is one of {'primitive', 'exp', 'tan', 'auto'} for the primitive, + hyperexponential, and hypertangent cases, respectively. If case is 'auto', + it will attempt to determine the type of the derivation automatically. + + See also + ======== + is_log_deriv_k_t_radical, is_deriv_k + + """ + fa, fd = fa.cancel(fd, include=True) + + # f must be simple + n, s = splitfactor(fd, DE) + if not s.is_one: + pass + + z = z or Dummy('z') + H, b = residue_reduce(fa, fd, DE, z=z) + if not b: + # I will have to verify, but I believe that the answer should be + # None in this case. This should never happen for the + # functions given when solving the parametric logarithmic + # derivative problem when integration elementary functions (see + # Bronstein's book, page 255), so most likely this indicates a bug. + return None + + roots = [(i, i.real_roots()) for i, _ in H] + if not all(len(j) == i.degree() and all(k.is_Rational for k in j) for + i, j in roots): + # If f is the logarithmic derivative of a k(t)-radical, then all the + # roots of the resultant must be rational numbers. + return None + + # [(a, i), ...], where i*log(a) is a term in the log-part of the integral + # of f + respolys, residues = list(zip(*roots)) or [[], []] + # Note: this might be empty, but everything below should work find in that + # case (it should be the same as if it were [[1, 1]]) + residueterms = [(H[j][1].subs(z, i), i) for j in range(len(H)) for + i in residues[j]] + + # TODO: finish writing this and write tests + + p = cancel(fa.as_expr()/fd.as_expr() - residue_reduce_derivation(H, DE, z)) + + p = p.as_poly(DE.t) + if p is None: + # f - Dg will be in k[t] if f is the logarithmic derivative of a k(t)-radical + return None + + if p.degree(DE.t) >= max(1, DE.d.degree(DE.t)): + return None + + if case == 'auto': + case = DE.case + + if case == 'exp': + wa, wd = derivation(DE.t, DE).cancel(Poly(DE.t, DE.t), include=True) + with DecrementLevel(DE): + pa, pd = frac_in(p, DE.t, cancel=True) + wa, wd = frac_in((wa, wd), DE.t) + A = parametric_log_deriv(pa, pd, wa, wd, DE) + if A is None: + return None + n, e, u = A + u *= DE.t**e + + elif case == 'primitive': + with DecrementLevel(DE): + pa, pd = frac_in(p, DE.t) + A = is_log_deriv_k_t_radical_in_field(pa, pd, DE, case='auto') + if A is None: + return None + n, u = A + + elif case == 'base': + # TODO: we can use more efficient residue reduction from ratint() + if not fd.is_sqf or fa.degree() >= fd.degree(): + # f is the logarithmic derivative in the base case if and only if + # f = fa/fd, fd is square-free, deg(fa) < deg(fd), and + # gcd(fa, fd) == 1. The last condition is handled by cancel() above. + return None + # Note: if residueterms = [], returns (1, 1) + # f had better be 0 in that case. + n = S.One*reduce(ilcm, [i.as_numer_denom()[1] for _, i in residueterms], 1) + u = Mul(*[Pow(i, j*n) for i, j in residueterms]) + return (n, u) + + elif case == 'tan': + raise NotImplementedError("The hypertangent case is " + "not yet implemented for is_log_deriv_k_t_radical_in_field()") + + elif case in ('other_linear', 'other_nonlinear'): + # XXX: If these are supported by the structure theorems, change to NotImplementedError. + raise ValueError("The %s case is not supported in this function." % case) + + else: + raise ValueError("case must be one of {'primitive', 'exp', 'tan', " + "'base', 'auto'}, not %s" % case) + + common_denom = S.One*reduce(ilcm, [i.as_numer_denom()[1] for i in [j for _, j in + residueterms]] + [n], 1) + residueterms = [(i, j*common_denom) for i, j in residueterms] + m = common_denom//n + if common_denom != n*m: # Verify exact division + raise ValueError("Inexact division") + u = cancel(u**m*Mul(*[Pow(i, j) for i, j in residueterms])) + + return (common_denom, u) diff --git a/.venv/lib/python3.13/site-packages/sympy/integrals/quadrature.py b/.venv/lib/python3.13/site-packages/sympy/integrals/quadrature.py new file mode 100644 index 0000000000000000000000000000000000000000..b518bd427dc9980d6a941d2e1ef4d139c5f0f5f9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/integrals/quadrature.py @@ -0,0 +1,617 @@ +from sympy.core import S, Dummy, pi +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.trigonometric import sin, cos +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.gamma_functions import gamma +from sympy.polys.orthopolys import (legendre_poly, laguerre_poly, + hermite_poly, jacobi_poly) +from sympy.polys.rootoftools import RootOf + + +def gauss_legendre(n, n_digits): + r""" + Computes the Gauss-Legendre quadrature [1]_ points and weights. + + Explanation + =========== + + The Gauss-Legendre quadrature approximates the integral: + + .. math:: + \int_{-1}^1 f(x)\,dx \approx \sum_{i=1}^n w_i f(x_i) + + The nodes `x_i` of an order `n` quadrature rule are the roots of `P_n` + and the weights `w_i` are given by: + + .. math:: + w_i = \frac{2}{\left(1-x_i^2\right) \left(P'_n(x_i)\right)^2} + + Parameters + ========== + + n : + The order of quadrature. + n_digits : + Number of significant digits of the points and weights to return. + + Returns + ======= + + (x, w) : the ``x`` and ``w`` are lists of points and weights as Floats. + The points `x_i` and weights `w_i` are returned as ``(x, w)`` + tuple of lists. + + Examples + ======== + + >>> from sympy.integrals.quadrature import gauss_legendre + >>> x, w = gauss_legendre(3, 5) + >>> x + [-0.7746, 0, 0.7746] + >>> w + [0.55556, 0.88889, 0.55556] + >>> x, w = gauss_legendre(4, 5) + >>> x + [-0.86114, -0.33998, 0.33998, 0.86114] + >>> w + [0.34785, 0.65215, 0.65215, 0.34785] + + See Also + ======== + + gauss_laguerre, gauss_gen_laguerre, gauss_hermite, gauss_chebyshev_t, gauss_chebyshev_u, gauss_jacobi, gauss_lobatto + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gaussian_quadrature + .. [2] https://people.sc.fsu.edu/~jburkardt/cpp_src/legendre_rule/legendre_rule.html + """ + x = Dummy("x") + p = legendre_poly(n, x, polys=True) + pd = p.diff(x) + xi = [] + w = [] + for r in p.real_roots(): + if isinstance(r, RootOf): + r = r.eval_rational(S.One/10**(n_digits+2)) + xi.append(r.n(n_digits)) + w.append((2/((1-r**2) * pd.subs(x, r)**2)).n(n_digits)) + return xi, w + + +def gauss_laguerre(n, n_digits): + r""" + Computes the Gauss-Laguerre quadrature [1]_ points and weights. + + Explanation + =========== + + The Gauss-Laguerre quadrature approximates the integral: + + .. math:: + \int_0^{\infty} e^{-x} f(x)\,dx \approx \sum_{i=1}^n w_i f(x_i) + + + The nodes `x_i` of an order `n` quadrature rule are the roots of `L_n` + and the weights `w_i` are given by: + + .. math:: + w_i = \frac{x_i}{(n+1)^2 \left(L_{n+1}(x_i)\right)^2} + + Parameters + ========== + + n : + The order of quadrature. + n_digits : + Number of significant digits of the points and weights to return. + + Returns + ======= + + (x, w) : The ``x`` and ``w`` are lists of points and weights as Floats. + The points `x_i` and weights `w_i` are returned as ``(x, w)`` + tuple of lists. + + Examples + ======== + + >>> from sympy.integrals.quadrature import gauss_laguerre + >>> x, w = gauss_laguerre(3, 5) + >>> x + [0.41577, 2.2943, 6.2899] + >>> w + [0.71109, 0.27852, 0.010389] + >>> x, w = gauss_laguerre(6, 5) + >>> x + [0.22285, 1.1889, 2.9927, 5.7751, 9.8375, 15.983] + >>> w + [0.45896, 0.417, 0.11337, 0.010399, 0.00026102, 8.9855e-7] + + See Also + ======== + + gauss_legendre, gauss_gen_laguerre, gauss_hermite, gauss_chebyshev_t, gauss_chebyshev_u, gauss_jacobi, gauss_lobatto + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gauss%E2%80%93Laguerre_quadrature + .. [2] https://people.sc.fsu.edu/~jburkardt/cpp_src/laguerre_rule/laguerre_rule.html + """ + x = Dummy("x") + p = laguerre_poly(n, x, polys=True) + p1 = laguerre_poly(n+1, x, polys=True) + xi = [] + w = [] + for r in p.real_roots(): + if isinstance(r, RootOf): + r = r.eval_rational(S.One/10**(n_digits+2)) + xi.append(r.n(n_digits)) + w.append((r/((n+1)**2 * p1.subs(x, r)**2)).n(n_digits)) + return xi, w + + +def gauss_hermite(n, n_digits): + r""" + Computes the Gauss-Hermite quadrature [1]_ points and weights. + + Explanation + =========== + + The Gauss-Hermite quadrature approximates the integral: + + .. math:: + \int_{-\infty}^{\infty} e^{-x^2} f(x)\,dx \approx + \sum_{i=1}^n w_i f(x_i) + + The nodes `x_i` of an order `n` quadrature rule are the roots of `H_n` + and the weights `w_i` are given by: + + .. math:: + w_i = \frac{2^{n-1} n! \sqrt{\pi}}{n^2 \left(H_{n-1}(x_i)\right)^2} + + Parameters + ========== + + n : + The order of quadrature. + n_digits : + Number of significant digits of the points and weights to return. + + Returns + ======= + + (x, w) : The ``x`` and ``w`` are lists of points and weights as Floats. + The points `x_i` and weights `w_i` are returned as ``(x, w)`` + tuple of lists. + + Examples + ======== + + >>> from sympy.integrals.quadrature import gauss_hermite + >>> x, w = gauss_hermite(3, 5) + >>> x + [-1.2247, 0, 1.2247] + >>> w + [0.29541, 1.1816, 0.29541] + + >>> x, w = gauss_hermite(6, 5) + >>> x + [-2.3506, -1.3358, -0.43608, 0.43608, 1.3358, 2.3506] + >>> w + [0.00453, 0.15707, 0.72463, 0.72463, 0.15707, 0.00453] + + See Also + ======== + + gauss_legendre, gauss_laguerre, gauss_gen_laguerre, gauss_chebyshev_t, gauss_chebyshev_u, gauss_jacobi, gauss_lobatto + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gauss-Hermite_Quadrature + .. [2] https://people.sc.fsu.edu/~jburkardt/cpp_src/hermite_rule/hermite_rule.html + .. [3] https://people.sc.fsu.edu/~jburkardt/cpp_src/gen_hermite_rule/gen_hermite_rule.html + """ + x = Dummy("x") + p = hermite_poly(n, x, polys=True) + p1 = hermite_poly(n-1, x, polys=True) + xi = [] + w = [] + for r in p.real_roots(): + if isinstance(r, RootOf): + r = r.eval_rational(S.One/10**(n_digits+2)) + xi.append(r.n(n_digits)) + w.append(((2**(n-1) * factorial(n) * sqrt(pi)) / + (n**2 * p1.subs(x, r)**2)).n(n_digits)) + return xi, w + + +def gauss_gen_laguerre(n, alpha, n_digits): + r""" + Computes the generalized Gauss-Laguerre quadrature [1]_ points and weights. + + Explanation + =========== + + The generalized Gauss-Laguerre quadrature approximates the integral: + + .. math:: + \int_{0}^\infty x^{\alpha} e^{-x} f(x)\,dx \approx + \sum_{i=1}^n w_i f(x_i) + + The nodes `x_i` of an order `n` quadrature rule are the roots of + `L^{\alpha}_n` and the weights `w_i` are given by: + + .. math:: + w_i = \frac{\Gamma(\alpha+n)} + {n \Gamma(n) L^{\alpha}_{n-1}(x_i) L^{\alpha+1}_{n-1}(x_i)} + + Parameters + ========== + + n : + The order of quadrature. + + alpha : + The exponent of the singularity, `\alpha > -1`. + + n_digits : + Number of significant digits of the points and weights to return. + + Returns + ======= + + (x, w) : the ``x`` and ``w`` are lists of points and weights as Floats. + The points `x_i` and weights `w_i` are returned as ``(x, w)`` + tuple of lists. + + Examples + ======== + + >>> from sympy import S + >>> from sympy.integrals.quadrature import gauss_gen_laguerre + >>> x, w = gauss_gen_laguerre(3, -S.Half, 5) + >>> x + [0.19016, 1.7845, 5.5253] + >>> w + [1.4493, 0.31413, 0.00906] + + >>> x, w = gauss_gen_laguerre(4, 3*S.Half, 5) + >>> x + [0.97851, 2.9904, 6.3193, 11.712] + >>> w + [0.53087, 0.67721, 0.11895, 0.0023152] + + See Also + ======== + + gauss_legendre, gauss_laguerre, gauss_hermite, gauss_chebyshev_t, gauss_chebyshev_u, gauss_jacobi, gauss_lobatto + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gauss%E2%80%93Laguerre_quadrature + .. [2] https://people.sc.fsu.edu/~jburkardt/cpp_src/gen_laguerre_rule/gen_laguerre_rule.html + """ + x = Dummy("x") + p = laguerre_poly(n, x, alpha=alpha, polys=True) + p1 = laguerre_poly(n-1, x, alpha=alpha, polys=True) + p2 = laguerre_poly(n-1, x, alpha=alpha+1, polys=True) + xi = [] + w = [] + for r in p.real_roots(): + if isinstance(r, RootOf): + r = r.eval_rational(S.One/10**(n_digits+2)) + xi.append(r.n(n_digits)) + w.append((gamma(alpha+n) / + (n*gamma(n)*p1.subs(x, r)*p2.subs(x, r))).n(n_digits)) + return xi, w + + +def gauss_chebyshev_t(n, n_digits): + r""" + Computes the Gauss-Chebyshev quadrature [1]_ points and weights of + the first kind. + + Explanation + =========== + + The Gauss-Chebyshev quadrature of the first kind approximates the integral: + + .. math:: + \int_{-1}^{1} \frac{1}{\sqrt{1-x^2}} f(x)\,dx \approx + \sum_{i=1}^n w_i f(x_i) + + The nodes `x_i` of an order `n` quadrature rule are the roots of `T_n` + and the weights `w_i` are given by: + + .. math:: + w_i = \frac{\pi}{n} + + Parameters + ========== + + n : + The order of quadrature. + + n_digits : + Number of significant digits of the points and weights to return. + + Returns + ======= + + (x, w) : the ``x`` and ``w`` are lists of points and weights as Floats. + The points `x_i` and weights `w_i` are returned as ``(x, w)`` + tuple of lists. + + Examples + ======== + + >>> from sympy.integrals.quadrature import gauss_chebyshev_t + >>> x, w = gauss_chebyshev_t(3, 5) + >>> x + [0.86602, 0, -0.86602] + >>> w + [1.0472, 1.0472, 1.0472] + + >>> x, w = gauss_chebyshev_t(6, 5) + >>> x + [0.96593, 0.70711, 0.25882, -0.25882, -0.70711, -0.96593] + >>> w + [0.5236, 0.5236, 0.5236, 0.5236, 0.5236, 0.5236] + + See Also + ======== + + gauss_legendre, gauss_laguerre, gauss_hermite, gauss_gen_laguerre, gauss_chebyshev_u, gauss_jacobi, gauss_lobatto + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Chebyshev%E2%80%93Gauss_quadrature + .. [2] https://people.sc.fsu.edu/~jburkardt/cpp_src/chebyshev1_rule/chebyshev1_rule.html + """ + xi = [] + w = [] + for i in range(1, n+1): + xi.append((cos((2*i-S.One)/(2*n)*S.Pi)).n(n_digits)) + w.append((S.Pi/n).n(n_digits)) + return xi, w + + +def gauss_chebyshev_u(n, n_digits): + r""" + Computes the Gauss-Chebyshev quadrature [1]_ points and weights of + the second kind. + + Explanation + =========== + + The Gauss-Chebyshev quadrature of the second kind approximates the + integral: + + .. math:: + \int_{-1}^{1} \sqrt{1-x^2} f(x)\,dx \approx \sum_{i=1}^n w_i f(x_i) + + The nodes `x_i` of an order `n` quadrature rule are the roots of `U_n` + and the weights `w_i` are given by: + + .. math:: + w_i = \frac{\pi}{n+1} \sin^2 \left(\frac{i}{n+1}\pi\right) + + Parameters + ========== + + n : the order of quadrature + + n_digits : number of significant digits of the points and weights to return + + Returns + ======= + + (x, w) : the ``x`` and ``w`` are lists of points and weights as Floats. + The points `x_i` and weights `w_i` are returned as ``(x, w)`` + tuple of lists. + + Examples + ======== + + >>> from sympy.integrals.quadrature import gauss_chebyshev_u + >>> x, w = gauss_chebyshev_u(3, 5) + >>> x + [0.70711, 0, -0.70711] + >>> w + [0.3927, 0.7854, 0.3927] + + >>> x, w = gauss_chebyshev_u(6, 5) + >>> x + [0.90097, 0.62349, 0.22252, -0.22252, -0.62349, -0.90097] + >>> w + [0.084489, 0.27433, 0.42658, 0.42658, 0.27433, 0.084489] + + See Also + ======== + + gauss_legendre, gauss_laguerre, gauss_hermite, gauss_gen_laguerre, gauss_chebyshev_t, gauss_jacobi, gauss_lobatto + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Chebyshev%E2%80%93Gauss_quadrature + .. [2] https://people.sc.fsu.edu/~jburkardt/cpp_src/chebyshev2_rule/chebyshev2_rule.html + """ + xi = [] + w = [] + for i in range(1, n+1): + xi.append((cos(i/(n+S.One)*S.Pi)).n(n_digits)) + w.append((S.Pi/(n+S.One)*sin(i*S.Pi/(n+S.One))**2).n(n_digits)) + return xi, w + + +def gauss_jacobi(n, alpha, beta, n_digits): + r""" + Computes the Gauss-Jacobi quadrature [1]_ points and weights. + + Explanation + =========== + + The Gauss-Jacobi quadrature of the first kind approximates the integral: + + .. math:: + \int_{-1}^1 (1-x)^\alpha (1+x)^\beta f(x)\,dx \approx + \sum_{i=1}^n w_i f(x_i) + + The nodes `x_i` of an order `n` quadrature rule are the roots of + `P^{(\alpha,\beta)}_n` and the weights `w_i` are given by: + + .. math:: + w_i = -\frac{2n+\alpha+\beta+2}{n+\alpha+\beta+1} + \frac{\Gamma(n+\alpha+1)\Gamma(n+\beta+1)} + {\Gamma(n+\alpha+\beta+1)(n+1)!} + \frac{2^{\alpha+\beta}}{P'_n(x_i) + P^{(\alpha,\beta)}_{n+1}(x_i)} + + Parameters + ========== + + n : the order of quadrature + + alpha : the first parameter of the Jacobi Polynomial, `\alpha > -1` + + beta : the second parameter of the Jacobi Polynomial, `\beta > -1` + + n_digits : number of significant digits of the points and weights to return + + Returns + ======= + + (x, w) : the ``x`` and ``w`` are lists of points and weights as Floats. + The points `x_i` and weights `w_i` are returned as ``(x, w)`` + tuple of lists. + + Examples + ======== + + >>> from sympy import S + >>> from sympy.integrals.quadrature import gauss_jacobi + >>> x, w = gauss_jacobi(3, S.Half, -S.Half, 5) + >>> x + [-0.90097, -0.22252, 0.62349] + >>> w + [1.7063, 1.0973, 0.33795] + + >>> x, w = gauss_jacobi(6, 1, 1, 5) + >>> x + [-0.87174, -0.5917, -0.2093, 0.2093, 0.5917, 0.87174] + >>> w + [0.050584, 0.22169, 0.39439, 0.39439, 0.22169, 0.050584] + + See Also + ======== + + gauss_legendre, gauss_laguerre, gauss_hermite, gauss_gen_laguerre, + gauss_chebyshev_t, gauss_chebyshev_u, gauss_lobatto + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gauss%E2%80%93Jacobi_quadrature + .. [2] https://people.sc.fsu.edu/~jburkardt/cpp_src/jacobi_rule/jacobi_rule.html + .. [3] https://people.sc.fsu.edu/~jburkardt/cpp_src/gegenbauer_rule/gegenbauer_rule.html + """ + x = Dummy("x") + p = jacobi_poly(n, alpha, beta, x, polys=True) + pd = p.diff(x) + pn = jacobi_poly(n+1, alpha, beta, x, polys=True) + xi = [] + w = [] + for r in p.real_roots(): + if isinstance(r, RootOf): + r = r.eval_rational(S.One/10**(n_digits+2)) + xi.append(r.n(n_digits)) + w.append(( + - (2*n+alpha+beta+2) / (n+alpha+beta+S.One) * + (gamma(n+alpha+1)*gamma(n+beta+1)) / + (gamma(n+alpha+beta+S.One)*gamma(n+2)) * + 2**(alpha+beta) / (pd.subs(x, r) * pn.subs(x, r))).n(n_digits)) + return xi, w + + +def gauss_lobatto(n, n_digits): + r""" + Computes the Gauss-Lobatto quadrature [1]_ points and weights. + + Explanation + =========== + + The Gauss-Lobatto quadrature approximates the integral: + + .. math:: + \int_{-1}^1 f(x)\,dx \approx \sum_{i=1}^n w_i f(x_i) + + The nodes `x_i` of an order `n` quadrature rule are the roots of `P'_(n-1)` + and the weights `w_i` are given by: + + .. math:: + &w_i = \frac{2}{n(n-1) \left[P_{n-1}(x_i)\right]^2},\quad x\neq\pm 1\\ + &w_i = \frac{2}{n(n-1)},\quad x=\pm 1 + + Parameters + ========== + + n : the order of quadrature + + n_digits : number of significant digits of the points and weights to return + + Returns + ======= + + (x, w) : the ``x`` and ``w`` are lists of points and weights as Floats. + The points `x_i` and weights `w_i` are returned as ``(x, w)`` + tuple of lists. + + Examples + ======== + + >>> from sympy.integrals.quadrature import gauss_lobatto + >>> x, w = gauss_lobatto(3, 5) + >>> x + [-1, 0, 1] + >>> w + [0.33333, 1.3333, 0.33333] + >>> x, w = gauss_lobatto(4, 5) + >>> x + [-1, -0.44721, 0.44721, 1] + >>> w + [0.16667, 0.83333, 0.83333, 0.16667] + + See Also + ======== + + gauss_legendre,gauss_laguerre, gauss_gen_laguerre, gauss_hermite, gauss_chebyshev_t, gauss_chebyshev_u, gauss_jacobi + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gaussian_quadrature#Gauss.E2.80.93Lobatto_rules + .. [2] https://web.archive.org/web/20200118141346/http://people.math.sfu.ca/~cbm/aands/page_888.htm + """ + x = Dummy("x") + p = legendre_poly(n-1, x, polys=True) + pd = p.diff(x) + xi = [] + w = [] + for r in pd.real_roots(): + if isinstance(r, RootOf): + r = r.eval_rational(S.One/10**(n_digits+2)) + xi.append(r.n(n_digits)) + w.append((2/(n*(n-1) * p.subs(x, r)**2)).n(n_digits)) + + xi.insert(0, -1) + xi.append(1) + w.insert(0, (S(2)/(n*(n-1))).n(n_digits)) + w.append((S(2)/(n*(n-1))).n(n_digits)) + return xi, w diff --git a/.venv/lib/python3.13/site-packages/sympy/integrals/rationaltools.py b/.venv/lib/python3.13/site-packages/sympy/integrals/rationaltools.py new file mode 100644 index 0000000000000000000000000000000000000000..e95ff5da2e9d1be6f07d8fe6e9c572f692e92efb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/integrals/rationaltools.py @@ -0,0 +1,445 @@ +"""This module implements tools for integrating rational functions. """ + +from sympy.core.function import Lambda +from sympy.core.numbers import I +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.trigonometric import atan +from sympy.polys.polyerrors import DomainError +from sympy.polys.polyroots import roots +from sympy.polys.polytools import cancel +from sympy.polys.rootoftools import RootSum +from sympy.polys import Poly, resultant, ZZ + + +def ratint(f, x, **flags): + """ + Performs indefinite integration of rational functions. + + Explanation + =========== + + Given a field :math:`K` and a rational function :math:`f = p/q`, + where :math:`p` and :math:`q` are polynomials in :math:`K[x]`, + returns a function :math:`g` such that :math:`f = g'`. + + Examples + ======== + + >>> from sympy.integrals.rationaltools import ratint + >>> from sympy.abc import x + + >>> ratint(36/(x**5 - 2*x**4 - 2*x**3 + 4*x**2 + x - 2), x) + (12*x + 6)/(x**2 - 1) + 4*log(x - 2) - 4*log(x + 1) + + References + ========== + + .. [1] M. Bronstein, Symbolic Integration I: Transcendental + Functions, Second Edition, Springer-Verlag, 2005, pp. 35-70 + + See Also + ======== + + sympy.integrals.integrals.Integral.doit + sympy.integrals.rationaltools.ratint_logpart + sympy.integrals.rationaltools.ratint_ratpart + + """ + if isinstance(f, tuple): + p, q = f + else: + p, q = f.as_numer_denom() + + p, q = Poly(p, x, composite=False, field=True), Poly(q, x, composite=False, field=True) + + coeff, p, q = p.cancel(q) + poly, p = p.div(q) + + result = poly.integrate(x).as_expr() + + if p.is_zero: + return coeff*result + + g, h = ratint_ratpart(p, q, x) + + P, Q = h.as_numer_denom() + + P = Poly(P, x) + Q = Poly(Q, x) + + q, r = P.div(Q) + + result += g + q.integrate(x).as_expr() + + if not r.is_zero: + symbol = flags.get('symbol', 't') + + if not isinstance(symbol, Symbol): + t = Dummy(symbol) + else: + t = symbol.as_dummy() + + L = ratint_logpart(r, Q, x, t) + + real = flags.get('real') + + if real is None: + if isinstance(f, tuple): + p, q = f + atoms = p.atoms() | q.atoms() + else: + atoms = f.atoms() + + for elt in atoms - {x}: + if not elt.is_extended_real: + real = False + break + else: + real = True + + eps = S.Zero + + if not real: + for h, q in L: + _, h = h.primitive() + eps += RootSum( + q, Lambda(t, t*log(h.as_expr())), quadratic=True) + else: + for h, q in L: + _, h = h.primitive() + R = log_to_real(h, q, x, t) + + if R is not None: + eps += R + else: + eps += RootSum( + q, Lambda(t, t*log(h.as_expr())), quadratic=True) + + result += eps + + return coeff*result + + +def ratint_ratpart(f, g, x): + """ + Horowitz-Ostrogradsky algorithm. + + Explanation + =========== + + Given a field K and polynomials f and g in K[x], such that f and g + are coprime and deg(f) < deg(g), returns fractions A and B in K(x), + such that f/g = A' + B and B has square-free denominator. + + Examples + ======== + + >>> from sympy.integrals.rationaltools import ratint_ratpart + >>> from sympy.abc import x, y + >>> from sympy import Poly + >>> ratint_ratpart(Poly(1, x, domain='ZZ'), + ... Poly(x + 1, x, domain='ZZ'), x) + (0, 1/(x + 1)) + >>> ratint_ratpart(Poly(1, x, domain='EX'), + ... Poly(x**2 + y**2, x, domain='EX'), x) + (0, 1/(x**2 + y**2)) + >>> ratint_ratpart(Poly(36, x, domain='ZZ'), + ... Poly(x**5 - 2*x**4 - 2*x**3 + 4*x**2 + x - 2, x, domain='ZZ'), x) + ((12*x + 6)/(x**2 - 1), 12/(x**2 - x - 2)) + + See Also + ======== + + ratint, ratint_logpart + """ + from sympy.solvers.solvers import solve + + f = Poly(f, x) + g = Poly(g, x) + + u, v, _ = g.cofactors(g.diff()) + + n = u.degree() + m = v.degree() + + A_coeffs = [ Dummy('a' + str(n - i)) for i in range(0, n) ] + B_coeffs = [ Dummy('b' + str(m - i)) for i in range(0, m) ] + + C_coeffs = A_coeffs + B_coeffs + + A = Poly(A_coeffs, x, domain=ZZ[C_coeffs]) + B = Poly(B_coeffs, x, domain=ZZ[C_coeffs]) + + H = f - A.diff()*v + A*(u.diff()*v).quo(u) - B*u + + result = solve(H.coeffs(), C_coeffs) + + A = A.as_expr().subs(result) + B = B.as_expr().subs(result) + + rat_part = cancel(A/u.as_expr(), x) + log_part = cancel(B/v.as_expr(), x) + + return rat_part, log_part + + +def ratint_logpart(f, g, x, t=None): + r""" + Lazard-Rioboo-Trager algorithm. + + Explanation + =========== + + Given a field K and polynomials f and g in K[x], such that f and g + are coprime, deg(f) < deg(g) and g is square-free, returns a list + of tuples (s_i, q_i) of polynomials, for i = 1..n, such that s_i + in K[t, x] and q_i in K[t], and:: + + ___ ___ + d f d \ ` \ ` + -- - = -- ) ) a log(s_i(a, x)) + dx g dx /__, /__, + i=1..n a | q_i(a) = 0 + + Examples + ======== + + >>> from sympy.integrals.rationaltools import ratint_logpart + >>> from sympy.abc import x + >>> from sympy import Poly + >>> ratint_logpart(Poly(1, x, domain='ZZ'), + ... Poly(x**2 + x + 1, x, domain='ZZ'), x) + [(Poly(x + 3*_t/2 + 1/2, x, domain='QQ[_t]'), + ...Poly(3*_t**2 + 1, _t, domain='ZZ'))] + >>> ratint_logpart(Poly(12, x, domain='ZZ'), + ... Poly(x**2 - x - 2, x, domain='ZZ'), x) + [(Poly(x - 3*_t/8 - 1/2, x, domain='QQ[_t]'), + ...Poly(-_t**2 + 16, _t, domain='ZZ'))] + + See Also + ======== + + ratint, ratint_ratpart + """ + f, g = Poly(f, x), Poly(g, x) + + t = t or Dummy('t') + a, b = g, f - g.diff()*Poly(t, x) + + res, R = resultant(a, b, includePRS=True) + res = Poly(res, t, composite=False) + + assert res, "BUG: resultant(%s, %s) cannot be zero" % (a, b) + + R_map, H = {}, [] + + for r in R: + R_map[r.degree()] = r + + def _include_sign(c, sqf): + if c.is_extended_real and (c < 0) == True: + h, k = sqf[0] + c_poly = c.as_poly(h.gens) + sqf[0] = h*c_poly, k + + C, res_sqf = res.sqf_list() + _include_sign(C, res_sqf) + + for q, i in res_sqf: + _, q = q.primitive() + + if g.degree() == i: + H.append((g, q)) + else: + h = R_map[i] + h_lc = Poly(h.LC(), t, field=True) + + c, h_lc_sqf = h_lc.sqf_list(all=True) + _include_sign(c, h_lc_sqf) + + for a, j in h_lc_sqf: + h = h.quo(Poly(a.gcd(q)**j, x)) + + inv, coeffs = h_lc.invert(q), [S.One] + + for coeff in h.coeffs()[1:]: + coeff = coeff.as_poly(inv.gens) + T = (inv*coeff).rem(q) + coeffs.append(T.as_expr()) + + h = Poly(dict(list(zip(h.monoms(), coeffs))), x) + + H.append((h, q)) + + return H + + +def log_to_atan(f, g): + """ + Convert complex logarithms to real arctangents. + + Explanation + =========== + + Given a real field K and polynomials f and g in K[x], with g != 0, + returns a sum h of arctangents of polynomials in K[x], such that: + + dh d f + I g + -- = -- I log( ------- ) + dx dx f - I g + + Examples + ======== + + >>> from sympy.integrals.rationaltools import log_to_atan + >>> from sympy.abc import x + >>> from sympy import Poly, sqrt, S + >>> log_to_atan(Poly(x, x, domain='ZZ'), Poly(1, x, domain='ZZ')) + 2*atan(x) + >>> log_to_atan(Poly(x + S(1)/2, x, domain='QQ'), + ... Poly(sqrt(3)/2, x, domain='EX')) + 2*atan(2*sqrt(3)*x/3 + sqrt(3)/3) + + See Also + ======== + + log_to_real + """ + if f.degree() < g.degree(): + f, g = -g, f + + f = f.to_field() + g = g.to_field() + + p, q = f.div(g) + + if q.is_zero: + return 2*atan(p.as_expr()) + else: + s, t, h = g.gcdex(-f) + u = (f*s + g*t).quo(h) + A = 2*atan(u.as_expr()) + + return A + log_to_atan(s, t) + + +def _get_real_roots(f, x): + """get real roots of f if possible""" + rs = roots(f, filter='R') + + try: + num_roots = f.count_roots() + except DomainError: + return rs + else: + if len(rs) == num_roots: + return rs + else: + return None + + +def log_to_real(h, q, x, t): + r""" + Convert complex logarithms to real functions. + + Explanation + =========== + + Given real field K and polynomials h in K[t,x] and q in K[t], + returns real function f such that: + ___ + df d \ ` + -- = -- ) a log(h(a, x)) + dx dx /__, + a | q(a) = 0 + + Examples + ======== + + >>> from sympy.integrals.rationaltools import log_to_real + >>> from sympy.abc import x, y + >>> from sympy import Poly, S + >>> log_to_real(Poly(x + 3*y/2 + S(1)/2, x, domain='QQ[y]'), + ... Poly(3*y**2 + 1, y, domain='ZZ'), x, y) + 2*sqrt(3)*atan(2*sqrt(3)*x/3 + sqrt(3)/3)/3 + >>> log_to_real(Poly(x**2 - 1, x, domain='ZZ'), + ... Poly(-2*y + 1, y, domain='ZZ'), x, y) + log(x**2 - 1)/2 + + See Also + ======== + + log_to_atan + """ + from sympy.simplify.radsimp import collect + u, v = symbols('u,v', cls=Dummy) + + H = h.as_expr().xreplace({t: u + I*v}).expand() + Q = q.as_expr().xreplace({t: u + I*v}).expand() + + H_map = collect(H, I, evaluate=False) + Q_map = collect(Q, I, evaluate=False) + + a, b = H_map.get(S.One, S.Zero), H_map.get(I, S.Zero) + c, d = Q_map.get(S.One, S.Zero), Q_map.get(I, S.Zero) + + R = Poly(resultant(c, d, v), u) + + R_u = _get_real_roots(R, u) + + if R_u is None: + return None + + result = S.Zero + + for r_u in R_u.keys(): + C = Poly(c.xreplace({u: r_u}), v) + if not C: + # t was split into real and imaginary parts + # and denom Q(u, v) = c + I*d. We just found + # that c(r_u) is 0 so the roots are in d + C = Poly(d.xreplace({u: r_u}), v) + # we were going to reject roots from C that + # did not set d to zero, but since we are now + # using C = d and c is already 0, there is + # nothing to check + d = S.Zero + + R_v = _get_real_roots(C, v) + + if R_v is None: + return None + + R_v_paired = [] # take one from each pair of conjugate roots + for r_v in R_v: + if r_v not in R_v_paired and -r_v not in R_v_paired: + if r_v.is_negative or r_v.could_extract_minus_sign(): + R_v_paired.append(-r_v) + elif not r_v.is_zero: + R_v_paired.append(r_v) + + for r_v in R_v_paired: + + D = d.xreplace({u: r_u, v: r_v}) + + if D.evalf(chop=True) != 0: + continue + + A = Poly(a.xreplace({u: r_u, v: r_v}), x) + B = Poly(b.xreplace({u: r_u, v: r_v}), x) + + AB = (A**2 + B**2).as_expr() + + result += r_u*log(AB) + r_v*log_to_atan(A, B) + + R_q = _get_real_roots(q, t) + + if R_q is None: + return None + + for r in R_q.keys(): + result += r*log(h.as_expr().subs(t, r)) + + return result diff --git a/.venv/lib/python3.13/site-packages/sympy/integrals/rde.py b/.venv/lib/python3.13/site-packages/sympy/integrals/rde.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb14a1d14b743b1e0885bf25a0d4b409e8a610d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/integrals/rde.py @@ -0,0 +1,800 @@ +""" +Algorithms for solving the Risch differential equation. + +Given a differential field K of characteristic 0 that is a simple +monomial extension of a base field k and f, g in K, the Risch +Differential Equation problem is to decide if there exist y in K such +that Dy + f*y == g and to find one if there are some. If t is a +monomial over k and the coefficients of f and g are in k(t), then y is +in k(t), and the outline of the algorithm here is given as: + +1. Compute the normal part n of the denominator of y. The problem is +then reduced to finding y' in k, where y == y'/n. +2. Compute the special part s of the denominator of y. The problem is +then reduced to finding y'' in k[t], where y == y''/(n*s) +3. Bound the degree of y''. +4. Reduce the equation Dy + f*y == g to a similar equation with f, g in +k[t]. +5. Find the solutions in k[t] of bounded degree of the reduced equation. + +See Chapter 6 of "Symbolic Integration I: Transcendental Functions" by +Manuel Bronstein. See also the docstring of risch.py. +""" + +from operator import mul +from functools import reduce + +from sympy.core import oo +from sympy.core.symbol import Dummy + +from sympy.polys import Poly, gcd, ZZ, cancel + +from sympy.functions.elementary.complexes import (im, re) +from sympy.functions.elementary.miscellaneous import sqrt + +from sympy.integrals.risch import (gcdex_diophantine, frac_in, derivation, + splitfactor, NonElementaryIntegralException, DecrementLevel, recognize_log_derivative) + +# TODO: Add messages to NonElementaryIntegralException errors + + +def order_at(a, p, t): + """ + Computes the order of a at p, with respect to t. + + Explanation + =========== + + For a, p in k[t], the order of a at p is defined as nu_p(a) = max({n + in Z+ such that p**n|a}), where a != 0. If a == 0, nu_p(a) = +oo. + + To compute the order at a rational function, a/b, use the fact that + nu_p(a/b) == nu_p(a) - nu_p(b). + """ + if a.is_zero: + return oo + if p == Poly(t, t): + return a.as_poly(t).ET()[0][0] + + # Uses binary search for calculating the power. power_list collects the tuples + # (p^k,k) where each k is some power of 2. After deciding the largest k + # such that k is power of 2 and p^k|a the loop iteratively calculates + # the actual power. + power_list = [] + p1 = p + r = a.rem(p1) + tracks_power = 1 + while r.is_zero: + power_list.append((p1,tracks_power)) + p1 = p1*p1 + tracks_power *= 2 + r = a.rem(p1) + n = 0 + product = Poly(1, t) + while len(power_list) != 0: + final = power_list.pop() + productf = product*final[0] + r = a.rem(productf) + if r.is_zero: + n += final[1] + product = productf + return n + + +def order_at_oo(a, d, t): + """ + Computes the order of a/d at oo (infinity), with respect to t. + + For f in k(t), the order or f at oo is defined as deg(d) - deg(a), where + f == a/d. + """ + if a.is_zero: + return oo + return d.degree(t) - a.degree(t) + + +def weak_normalizer(a, d, DE, z=None): + """ + Weak normalization. + + Explanation + =========== + + Given a derivation D on k[t] and f == a/d in k(t), return q in k[t] + such that f - Dq/q is weakly normalized with respect to t. + + f in k(t) is said to be "weakly normalized" with respect to t if + residue_p(f) is not a positive integer for any normal irreducible p + in k[t] such that f is in R_p (Definition 6.1.1). If f has an + elementary integral, this is equivalent to no logarithm of + integral(f) whose argument depends on t has a positive integer + coefficient, where the arguments of the logarithms not in k(t) are + in k[t]. + + Returns (q, f - Dq/q) + """ + z = z or Dummy('z') + dn, ds = splitfactor(d, DE) + + # Compute d1, where dn == d1*d2**2*...*dn**n is a square-free + # factorization of d. + g = gcd(dn, dn.diff(DE.t)) + d_sqf_part = dn.quo(g) + d1 = d_sqf_part.quo(gcd(d_sqf_part, g)) + + a1, b = gcdex_diophantine(d.quo(d1).as_poly(DE.t), d1.as_poly(DE.t), + a.as_poly(DE.t)) + r = (a - Poly(z, DE.t)*derivation(d1, DE)).as_poly(DE.t).resultant( + d1.as_poly(DE.t)) + r = Poly(r, z) + + if not r.expr.has(z): + return (Poly(1, DE.t), (a, d)) + + N = [i for i in r.real_roots() if i in ZZ and i > 0] + + q = reduce(mul, [gcd(a - Poly(n, DE.t)*derivation(d1, DE), d1) for n in N], + Poly(1, DE.t)) + + dq = derivation(q, DE) + sn = q*a - d*dq + sd = q*d + sn, sd = sn.cancel(sd, include=True) + + return (q, (sn, sd)) + + +def normal_denom(fa, fd, ga, gd, DE): + """ + Normal part of the denominator. + + Explanation + =========== + + Given a derivation D on k[t] and f, g in k(t) with f weakly + normalized with respect to t, either raise NonElementaryIntegralException, + in which case the equation Dy + f*y == g has no solution in k(t), or the + quadruplet (a, b, c, h) such that a, h in k[t], b, c in k, and for any + solution y in k(t) of Dy + f*y == g, q = y*h in k satisfies + a*Dq + b*q == c. + + This constitutes step 1 in the outline given in the rde.py docstring. + """ + dn, ds = splitfactor(fd, DE) + en, es = splitfactor(gd, DE) + + p = dn.gcd(en) + h = en.gcd(en.diff(DE.t)).quo(p.gcd(p.diff(DE.t))) + + a = dn*h + c = a*h + if c.div(en)[1]: + # en does not divide dn*h**2 + raise NonElementaryIntegralException + ca = c*ga + ca, cd = ca.cancel(gd, include=True) + + ba = a*fa - dn*derivation(h, DE)*fd + ba, bd = ba.cancel(fd, include=True) + + # (dn*h, dn*h*f - dn*Dh, dn*h**2*g, h) + return (a, (ba, bd), (ca, cd), h) + + +def special_denom(a, ba, bd, ca, cd, DE, case='auto'): + """ + Special part of the denominator. + + Explanation + =========== + + case is one of {'exp', 'tan', 'primitive'} for the hyperexponential, + hypertangent, and primitive cases, respectively. For the + hyperexponential (resp. hypertangent) case, given a derivation D on + k[t] and a in k[t], b, c, in k with Dt/t in k (resp. Dt/(t**2 + 1) in + k, sqrt(-1) not in k), a != 0, and gcd(a, t) == 1 (resp. + gcd(a, t**2 + 1) == 1), return the quadruplet (A, B, C, 1/h) such that + A, B, C, h in k[t] and for any solution q in k of a*Dq + b*q == c, + r = qh in k[t] satisfies A*Dr + B*r == C. + + For ``case == 'primitive'``, k == k[t], so it returns (a, b, c, 1) in + this case. + + This constitutes step 2 of the outline given in the rde.py docstring. + """ + # TODO: finish writing this and write tests + + if case == 'auto': + case = DE.case + + if case == 'exp': + p = Poly(DE.t, DE.t) + elif case == 'tan': + p = Poly(DE.t**2 + 1, DE.t) + elif case in ('primitive', 'base'): + B = ba.to_field().quo(bd) + C = ca.to_field().quo(cd) + return (a, B, C, Poly(1, DE.t)) + else: + raise ValueError("case must be one of {'exp', 'tan', 'primitive', " + "'base'}, not %s." % case) + + nb = order_at(ba, p, DE.t) - order_at(bd, p, DE.t) + nc = order_at(ca, p, DE.t) - order_at(cd, p, DE.t) + + n = min(0, nc - min(0, nb)) + if not nb: + # Possible cancellation. + from .prde import parametric_log_deriv + if case == 'exp': + dcoeff = DE.d.quo(Poly(DE.t, DE.t)) + with DecrementLevel(DE): # We are guaranteed to not have problems, + # because case != 'base'. + alphaa, alphad = frac_in(-ba.eval(0)/bd.eval(0)/a.eval(0), DE.t) + etaa, etad = frac_in(dcoeff, DE.t) + A = parametric_log_deriv(alphaa, alphad, etaa, etad, DE) + if A is not None: + Q, m, z = A + if Q == 1: + n = min(n, m) + + elif case == 'tan': + dcoeff = DE.d.quo(Poly(DE.t**2+1, DE.t)) + with DecrementLevel(DE): # We are guaranteed to not have problems, + # because case != 'base'. + alphaa, alphad = frac_in(im(-ba.eval(sqrt(-1))/bd.eval(sqrt(-1))/a.eval(sqrt(-1))), DE.t) + betaa, betad = frac_in(re(-ba.eval(sqrt(-1))/bd.eval(sqrt(-1))/a.eval(sqrt(-1))), DE.t) + etaa, etad = frac_in(dcoeff, DE.t) + + if recognize_log_derivative(Poly(2, DE.t)*betaa, betad, DE): + A = parametric_log_deriv(alphaa*Poly(sqrt(-1), DE.t)*betad+alphad*betaa, alphad*betad, etaa, etad, DE) + if A is not None: + Q, m, z = A + if Q == 1: + n = min(n, m) + N = max(0, -nb, n - nc) + pN = p**N + pn = p**-n + + A = a*pN + B = ba*pN.quo(bd) + Poly(n, DE.t)*a*derivation(p, DE).quo(p)*pN + C = (ca*pN*pn).quo(cd) + h = pn + + # (a*p**N, (b + n*a*Dp/p)*p**N, c*p**(N - n), p**-n) + return (A, B, C, h) + + +def bound_degree(a, b, cQ, DE, case='auto', parametric=False): + """ + Bound on polynomial solutions. + + Explanation + =========== + + Given a derivation D on k[t] and ``a``, ``b``, ``c`` in k[t] with ``a != 0``, return + n in ZZ such that deg(q) <= n for any solution q in k[t] of + a*Dq + b*q == c, when parametric=False, or deg(q) <= n for any solution + c1, ..., cm in Const(k) and q in k[t] of a*Dq + b*q == Sum(ci*gi, (i, 1, m)) + when parametric=True. + + For ``parametric=False``, ``cQ`` is ``c``, a ``Poly``; for ``parametric=True``, ``cQ`` is Q == + [q1, ..., qm], a list of Polys. + + This constitutes step 3 of the outline given in the rde.py docstring. + """ + # TODO: finish writing this and write tests + + if case == 'auto': + case = DE.case + + da = a.degree(DE.t) + db = b.degree(DE.t) + + # The parametric and regular cases are identical, except for this part + if parametric: + dc = max(i.degree(DE.t) for i in cQ) + else: + dc = cQ.degree(DE.t) + + alpha = cancel(-b.as_poly(DE.t).LC().as_expr()/ + a.as_poly(DE.t).LC().as_expr()) + + if case == 'base': + n = max(0, dc - max(db, da - 1)) + if db == da - 1 and alpha.is_Integer: + n = max(0, alpha, dc - db) + + elif case == 'primitive': + if db > da: + n = max(0, dc - db) + else: + n = max(0, dc - da + 1) + + etaa, etad = frac_in(DE.d, DE.T[DE.level - 1]) + + t1 = DE.t + with DecrementLevel(DE): + alphaa, alphad = frac_in(alpha, DE.t) + if db == da - 1: + from .prde import limited_integrate + # if alpha == m*Dt + Dz for z in k and m in ZZ: + try: + (za, zd), m = limited_integrate(alphaa, alphad, [(etaa, etad)], + DE) + except NonElementaryIntegralException: + pass + else: + if len(m) != 1: + raise ValueError("Length of m should be 1") + n = max(n, m[0]) + + elif db == da: + # if alpha == Dz/z for z in k*: + # beta = -lc(a*Dz + b*z)/(z*lc(a)) + # if beta == m*Dt + Dw for w in k and m in ZZ: + # n = max(n, m) + from .prde import is_log_deriv_k_t_radical_in_field + A = is_log_deriv_k_t_radical_in_field(alphaa, alphad, DE) + if A is not None: + aa, z = A + if aa == 1: + beta = -(a*derivation(z, DE).as_poly(t1) + + b*z.as_poly(t1)).LC()/(z.as_expr()*a.LC()) + betaa, betad = frac_in(beta, DE.t) + from .prde import limited_integrate + try: + (za, zd), m = limited_integrate(betaa, betad, + [(etaa, etad)], DE) + except NonElementaryIntegralException: + pass + else: + if len(m) != 1: + raise ValueError("Length of m should be 1") + n = max(n, m[0].as_expr()) + + elif case == 'exp': + from .prde import parametric_log_deriv + + n = max(0, dc - max(db, da)) + if da == db: + etaa, etad = frac_in(DE.d.quo(Poly(DE.t, DE.t)), DE.T[DE.level - 1]) + with DecrementLevel(DE): + alphaa, alphad = frac_in(alpha, DE.t) + A = parametric_log_deriv(alphaa, alphad, etaa, etad, DE) + if A is not None: + # if alpha == m*Dt/t + Dz/z for z in k* and m in ZZ: + # n = max(n, m) + a, m, z = A + if a == 1: + n = max(n, m) + + elif case in ('tan', 'other_nonlinear'): + delta = DE.d.degree(DE.t) + lam = DE.d.LC() + alpha = cancel(alpha/lam) + n = max(0, dc - max(da + delta - 1, db)) + if db == da + delta - 1 and alpha.is_Integer: + n = max(0, alpha, dc - db) + + else: + raise ValueError("case must be one of {'exp', 'tan', 'primitive', " + "'other_nonlinear', 'base'}, not %s." % case) + + return n + + +def spde(a, b, c, n, DE): + """ + Rothstein's Special Polynomial Differential Equation algorithm. + + Explanation + =========== + + Given a derivation D on k[t], an integer n and ``a``,``b``,``c`` in k[t] with + ``a != 0``, either raise NonElementaryIntegralException, in which case the + equation a*Dq + b*q == c has no solution of degree at most ``n`` in + k[t], or return the tuple (B, C, m, alpha, beta) such that B, C, + alpha, beta in k[t], m in ZZ, and any solution q in k[t] of degree + at most n of a*Dq + b*q == c must be of the form + q == alpha*h + beta, where h in k[t], deg(h) <= m, and Dh + B*h == C. + + This constitutes step 4 of the outline given in the rde.py docstring. + """ + zero = Poly(0, DE.t) + + alpha = Poly(1, DE.t) + beta = Poly(0, DE.t) + + while True: + if c.is_zero: + return (zero, zero, 0, zero, beta) # -1 is more to the point + if (n < 0) is True: + raise NonElementaryIntegralException + + g = a.gcd(b) + if not c.rem(g).is_zero: # g does not divide c + raise NonElementaryIntegralException + + a, b, c = a.quo(g), b.quo(g), c.quo(g) + + if a.degree(DE.t) == 0: + b = b.to_field().quo(a) + c = c.to_field().quo(a) + return (b, c, n, alpha, beta) + + r, z = gcdex_diophantine(b, a, c) + b += derivation(a, DE) + c = z - derivation(r, DE) + n -= a.degree(DE.t) + + beta += alpha * r + alpha *= a + +def no_cancel_b_large(b, c, n, DE): + """ + Poly Risch Differential Equation - No cancellation: deg(b) large enough. + + Explanation + =========== + + Given a derivation D on k[t], ``n`` either an integer or +oo, and ``b``,``c`` + in k[t] with ``b != 0`` and either D == d/dt or + deg(b) > max(0, deg(D) - 1), either raise NonElementaryIntegralException, in + which case the equation ``Dq + b*q == c`` has no solution of degree at + most n in k[t], or a solution q in k[t] of this equation with + ``deg(q) < n``. + """ + q = Poly(0, DE.t) + + while not c.is_zero: + m = c.degree(DE.t) - b.degree(DE.t) + if not 0 <= m <= n: # n < 0 or m < 0 or m > n + raise NonElementaryIntegralException + + p = Poly(c.as_poly(DE.t).LC()/b.as_poly(DE.t).LC()*DE.t**m, DE.t, + expand=False) + q = q + p + n = m - 1 + c = c - derivation(p, DE) - b*p + + return q + + +def no_cancel_b_small(b, c, n, DE): + """ + Poly Risch Differential Equation - No cancellation: deg(b) small enough. + + Explanation + =========== + + Given a derivation D on k[t], ``n`` either an integer or +oo, and ``b``,``c`` + in k[t] with deg(b) < deg(D) - 1 and either D == d/dt or + deg(D) >= 2, either raise NonElementaryIntegralException, in which case the + equation Dq + b*q == c has no solution of degree at most n in k[t], + or a solution q in k[t] of this equation with deg(q) <= n, or the + tuple (h, b0, c0) such that h in k[t], b0, c0, in k, and for any + solution q in k[t] of degree at most n of Dq + bq == c, y == q - h + is a solution in k of Dy + b0*y == c0. + """ + q = Poly(0, DE.t) + + while not c.is_zero: + if n == 0: + m = 0 + else: + m = c.degree(DE.t) - DE.d.degree(DE.t) + 1 + + if not 0 <= m <= n: # n < 0 or m < 0 or m > n + raise NonElementaryIntegralException + + if m > 0: + p = Poly(c.as_poly(DE.t).LC()/(m*DE.d.as_poly(DE.t).LC())*DE.t**m, + DE.t, expand=False) + else: + if b.degree(DE.t) != c.degree(DE.t): + raise NonElementaryIntegralException + if b.degree(DE.t) == 0: + return (q, b.as_poly(DE.T[DE.level - 1]), + c.as_poly(DE.T[DE.level - 1])) + p = Poly(c.as_poly(DE.t).LC()/b.as_poly(DE.t).LC(), DE.t, + expand=False) + + q = q + p + n = m - 1 + c = c - derivation(p, DE) - b*p + + return q + + +# TODO: better name for this function +def no_cancel_equal(b, c, n, DE): + """ + Poly Risch Differential Equation - No cancellation: deg(b) == deg(D) - 1 + + Explanation + =========== + + Given a derivation D on k[t] with deg(D) >= 2, n either an integer + or +oo, and b, c in k[t] with deg(b) == deg(D) - 1, either raise + NonElementaryIntegralException, in which case the equation Dq + b*q == c has + no solution of degree at most n in k[t], or a solution q in k[t] of + this equation with deg(q) <= n, or the tuple (h, m, C) such that h + in k[t], m in ZZ, and C in k[t], and for any solution q in k[t] of + degree at most n of Dq + b*q == c, y == q - h is a solution in k[t] + of degree at most m of Dy + b*y == C. + """ + q = Poly(0, DE.t) + lc = cancel(-b.as_poly(DE.t).LC()/DE.d.as_poly(DE.t).LC()) + if lc.is_Integer and lc.is_positive: + M = lc + else: + M = -1 + + while not c.is_zero: + m = max(M, c.degree(DE.t) - DE.d.degree(DE.t) + 1) + + if not 0 <= m <= n: # n < 0 or m < 0 or m > n + raise NonElementaryIntegralException + + u = cancel(m*DE.d.as_poly(DE.t).LC() + b.as_poly(DE.t).LC()) + if u.is_zero: + return (q, m, c) + if m > 0: + p = Poly(c.as_poly(DE.t).LC()/u*DE.t**m, DE.t, expand=False) + else: + if c.degree(DE.t) != DE.d.degree(DE.t) - 1: + raise NonElementaryIntegralException + else: + p = c.as_poly(DE.t).LC()/b.as_poly(DE.t).LC() + + q = q + p + n = m - 1 + c = c - derivation(p, DE) - b*p + + return q + + +def cancel_primitive(b, c, n, DE): + """ + Poly Risch Differential Equation - Cancellation: Primitive case. + + Explanation + =========== + + Given a derivation D on k[t], n either an integer or +oo, ``b`` in k, and + ``c`` in k[t] with Dt in k and ``b != 0``, either raise + NonElementaryIntegralException, in which case the equation Dq + b*q == c + has no solution of degree at most n in k[t], or a solution q in k[t] of + this equation with deg(q) <= n. + """ + # Delayed imports + from .prde import is_log_deriv_k_t_radical_in_field + with DecrementLevel(DE): + ba, bd = frac_in(b, DE.t) + A = is_log_deriv_k_t_radical_in_field(ba, bd, DE) + if A is not None: + n, z = A + if n == 1: # b == Dz/z + raise NotImplementedError("is_deriv_in_field() is required to " + " solve this problem.") + # if z*c == Dp for p in k[t] and deg(p) <= n: + # return p/z + # else: + # raise NonElementaryIntegralException + + if c.is_zero: + return c # return 0 + + if n < c.degree(DE.t): + raise NonElementaryIntegralException + + q = Poly(0, DE.t) + while not c.is_zero: + m = c.degree(DE.t) + if n < m: + raise NonElementaryIntegralException + with DecrementLevel(DE): + a2a, a2d = frac_in(c.LC(), DE.t) + sa, sd = rischDE(ba, bd, a2a, a2d, DE) + stm = Poly(sa.as_expr()/sd.as_expr()*DE.t**m, DE.t, expand=False) + q += stm + n = m - 1 + c -= b*stm + derivation(stm, DE) + + return q + + +def cancel_exp(b, c, n, DE): + """ + Poly Risch Differential Equation - Cancellation: Hyperexponential case. + + Explanation + =========== + + Given a derivation D on k[t], n either an integer or +oo, ``b`` in k, and + ``c`` in k[t] with Dt/t in k and ``b != 0``, either raise + NonElementaryIntegralException, in which case the equation Dq + b*q == c + has no solution of degree at most n in k[t], or a solution q in k[t] of + this equation with deg(q) <= n. + """ + from .prde import parametric_log_deriv + eta = DE.d.quo(Poly(DE.t, DE.t)).as_expr() + + with DecrementLevel(DE): + etaa, etad = frac_in(eta, DE.t) + ba, bd = frac_in(b, DE.t) + A = parametric_log_deriv(ba, bd, etaa, etad, DE) + if A is not None: + a, m, z = A + if a == 1: + raise NotImplementedError("is_deriv_in_field() is required to " + "solve this problem.") + # if c*z*t**m == Dp for p in k and q = p/(z*t**m) in k[t] and + # deg(q) <= n: + # return q + # else: + # raise NonElementaryIntegralException + + if c.is_zero: + return c # return 0 + + if n < c.degree(DE.t): + raise NonElementaryIntegralException + + q = Poly(0, DE.t) + while not c.is_zero: + m = c.degree(DE.t) + if n < m: + raise NonElementaryIntegralException + # a1 = b + m*Dt/t + a1 = b.as_expr() + with DecrementLevel(DE): + # TODO: Write a dummy function that does this idiom + a1a, a1d = frac_in(a1, DE.t) + a1a = a1a*etad + etaa*a1d*Poly(m, DE.t) + a1d = a1d*etad + + a2a, a2d = frac_in(c.LC(), DE.t) + + sa, sd = rischDE(a1a, a1d, a2a, a2d, DE) + stm = Poly(sa.as_expr()/sd.as_expr()*DE.t**m, DE.t, expand=False) + q += stm + n = m - 1 + c -= b*stm + derivation(stm, DE) # deg(c) becomes smaller + return q + + +def solve_poly_rde(b, cQ, n, DE, parametric=False): + """ + Solve a Polynomial Risch Differential Equation with degree bound ``n``. + + This constitutes step 4 of the outline given in the rde.py docstring. + + For parametric=False, cQ is c, a Poly; for parametric=True, cQ is Q == + [q1, ..., qm], a list of Polys. + """ + # No cancellation + if not b.is_zero and (DE.case == 'base' or + b.degree(DE.t) > max(0, DE.d.degree(DE.t) - 1)): + + if parametric: + # Delayed imports + from .prde import prde_no_cancel_b_large + return prde_no_cancel_b_large(b, cQ, n, DE) + return no_cancel_b_large(b, cQ, n, DE) + + elif (b.is_zero or b.degree(DE.t) < DE.d.degree(DE.t) - 1) and \ + (DE.case == 'base' or DE.d.degree(DE.t) >= 2): + + if parametric: + from .prde import prde_no_cancel_b_small + return prde_no_cancel_b_small(b, cQ, n, DE) + + R = no_cancel_b_small(b, cQ, n, DE) + + if isinstance(R, Poly): + return R + else: + # XXX: Might k be a field? (pg. 209) + h, b0, c0 = R + with DecrementLevel(DE): + b0, c0 = b0.as_poly(DE.t), c0.as_poly(DE.t) + if b0 is None: # See above comment + raise ValueError("b0 should be a non-Null value") + if c0 is None: + raise ValueError("c0 should be a non-Null value") + y = solve_poly_rde(b0, c0, n, DE).as_poly(DE.t) + return h + y + + elif DE.d.degree(DE.t) >= 2 and b.degree(DE.t) == DE.d.degree(DE.t) - 1 and \ + n > -b.as_poly(DE.t).LC()/DE.d.as_poly(DE.t).LC(): + + # TODO: Is this check necessary, and if so, what should it do if it fails? + # b comes from the first element returned from spde() + if not b.as_poly(DE.t).LC().is_number: + raise TypeError("Result should be a number") + + if parametric: + raise NotImplementedError("prde_no_cancel_b_equal() is not yet " + "implemented.") + + R = no_cancel_equal(b, cQ, n, DE) + + if isinstance(R, Poly): + return R + else: + h, m, C = R + # XXX: Or should it be rischDE()? + y = solve_poly_rde(b, C, m, DE) + return h + y + + else: + # Cancellation + if b.is_zero: + raise NotImplementedError("Remaining cases for Poly (P)RDE are " + "not yet implemented (is_deriv_in_field() required).") + else: + if DE.case == 'exp': + if parametric: + raise NotImplementedError("Parametric RDE cancellation " + "hyperexponential case is not yet implemented.") + return cancel_exp(b, cQ, n, DE) + + elif DE.case == 'primitive': + if parametric: + raise NotImplementedError("Parametric RDE cancellation " + "primitive case is not yet implemented.") + return cancel_primitive(b, cQ, n, DE) + + else: + raise NotImplementedError("Other Poly (P)RDE cancellation " + "cases are not yet implemented (%s)." % DE.case) + + if parametric: + raise NotImplementedError("Remaining cases for Poly PRDE not yet " + "implemented.") + raise NotImplementedError("Remaining cases for Poly RDE not yet " + "implemented.") + + +def rischDE(fa, fd, ga, gd, DE): + """ + Solve a Risch Differential Equation: Dy + f*y == g. + + Explanation + =========== + + See the outline in the docstring of rde.py for more information + about the procedure used. Either raise NonElementaryIntegralException, in + which case there is no solution y in the given differential field, + or return y in k(t) satisfying Dy + f*y == g, or raise + NotImplementedError, in which case, the algorithms necessary to + solve the given Risch Differential Equation have not yet been + implemented. + """ + _, (fa, fd) = weak_normalizer(fa, fd, DE) + a, (ba, bd), (ca, cd), hn = normal_denom(fa, fd, ga, gd, DE) + A, B, C, hs = special_denom(a, ba, bd, ca, cd, DE) + try: + # Until this is fully implemented, use oo. Note that this will almost + # certainly cause non-termination in spde() (unless A == 1), and + # *might* lead to non-termination in the next step for a nonelementary + # integral (I don't know for certain yet). Fortunately, spde() is + # currently written recursively, so this will just give + # RuntimeError: maximum recursion depth exceeded. + n = bound_degree(A, B, C, DE) + except NotImplementedError: + # Useful for debugging: + # import warnings + # warnings.warn("rischDE: Proceeding with n = oo; may cause " + # "non-termination.") + n = oo + + B, C, m, alpha, beta = spde(A, B, C, n, DE) + if C.is_zero: + y = C + else: + y = solve_poly_rde(B, C, m, DE) + + return (alpha*y + beta, hn*hs) diff --git a/.venv/lib/python3.13/site-packages/sympy/integrals/risch.py b/.venv/lib/python3.13/site-packages/sympy/integrals/risch.py new file mode 100644 index 0000000000000000000000000000000000000000..89e5f10bbb1d011d98a5884ce74ab25b615e1c51 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/integrals/risch.py @@ -0,0 +1,1851 @@ +""" +The Risch Algorithm for transcendental function integration. + +The core algorithms for the Risch algorithm are here. The subproblem +algorithms are in the rde.py and prde.py files for the Risch +Differential Equation solver and the parametric problems solvers, +respectively. All important information concerning the differential extension +for an integrand is stored in a DifferentialExtension object, which in the code +is usually called DE. Throughout the code and Inside the DifferentialExtension +object, the conventions/attribute names are that the base domain is QQ and each +differential extension is x, t0, t1, ..., tn-1 = DE.t. DE.x is the variable of +integration (Dx == 1), DE.D is a list of the derivatives of +x, t1, t2, ..., tn-1 = t, DE.T is the list [x, t1, t2, ..., tn-1], DE.t is the +outer-most variable of the differential extension at the given level (the level +can be adjusted using DE.increment_level() and DE.decrement_level()), +k is the field C(x, t0, ..., tn-2), where C is the constant field. The +numerator of a fraction is denoted by a and the denominator by +d. If the fraction is named f, fa == numer(f) and fd == denom(f). +Fractions are returned as tuples (fa, fd). DE.d and DE.t are used to +represent the topmost derivation and extension variable, respectively. +The docstring of a function signifies whether an argument is in k[t], in +which case it will just return a Poly in t, or in k(t), in which case it +will return the fraction (fa, fd). Other variable names probably come +from the names used in Bronstein's book. +""" +from types import GeneratorType +from functools import reduce + +from sympy.core.function import Lambda +from sympy.core.mul import Mul +from sympy.core.intfunc import ilcm +from sympy.core.numbers import I +from sympy.core.power import Pow +from sympy.core.relational import Ne +from sympy.core.singleton import S +from sympy.core.sorting import ordered, default_sort_key +from sympy.core.symbol import Dummy, Symbol +from sympy.functions.elementary.exponential import log, exp +from sympy.functions.elementary.hyperbolic import (cosh, coth, sinh, + tanh) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (atan, sin, cos, + tan, acot, cot, asin, acos) +from .integrals import integrate, Integral +from .heurisch import _symbols +from sympy.polys.polyerrors import PolynomialError +from sympy.polys.polytools import (real_roots, cancel, Poly, gcd, + reduced) +from sympy.polys.rootoftools import RootSum +from sympy.utilities.iterables import numbered_symbols + + +def integer_powers(exprs): + """ + Rewrites a list of expressions as integer multiples of each other. + + Explanation + =========== + + For example, if you have [x, x/2, x**2 + 1, 2*x/3], then you can rewrite + this as [(x/6) * 6, (x/6) * 3, (x**2 + 1) * 1, (x/6) * 4]. This is useful + in the Risch integration algorithm, where we must write exp(x) + exp(x/2) + as (exp(x/2))**2 + exp(x/2), but not as exp(x) + sqrt(exp(x)) (this is + because only the transcendental case is implemented and we therefore cannot + integrate algebraic extensions). The integer multiples returned by this + function for each term are the smallest possible (their content equals 1). + + Returns a list of tuples where the first element is the base term and the + second element is a list of `(item, factor)` terms, where `factor` is the + integer multiplicative factor that must multiply the base term to obtain + the original item. + + The easiest way to understand this is to look at an example: + + >>> from sympy.abc import x + >>> from sympy.integrals.risch import integer_powers + >>> integer_powers([x, x/2, x**2 + 1, 2*x/3]) + [(x/6, [(x, 6), (x/2, 3), (2*x/3, 4)]), (x**2 + 1, [(x**2 + 1, 1)])] + + We can see how this relates to the example at the beginning of the + docstring. It chose x/6 as the first base term. Then, x can be written as + (x/2) * 2, so we get (0, 2), and so on. Now only element (x**2 + 1) + remains, and there are no other terms that can be written as a rational + multiple of that, so we get that it can be written as (x**2 + 1) * 1. + + """ + # Here is the strategy: + + # First, go through each term and determine if it can be rewritten as a + # rational multiple of any of the terms gathered so far. + # cancel(a/b).is_Rational is sufficient for this. If it is a multiple, we + # add its multiple to the dictionary. + + terms = {} + for term in exprs: + for trm, trm_list in terms.items(): + a = cancel(term/trm) + if a.is_Rational: + trm_list.append((term, a)) + break + else: + terms[term] = [(term, S.One)] + + # After we have done this, we have all the like terms together, so we just + # need to find a common denominator so that we can get the base term and + # integer multiples such that each term can be written as an integer + # multiple of the base term, and the content of the integers is 1. + + newterms = {} + for term, term_list in terms.items(): + common_denom = reduce(ilcm, [i.as_numer_denom()[1] for _, i in + term_list]) + newterm = term/common_denom + newmults = [(i, j*common_denom) for i, j in term_list] + newterms[newterm] = newmults + + return sorted(iter(newterms.items()), key=lambda item: item[0].sort_key()) + + +class DifferentialExtension: + """ + A container for all the information relating to a differential extension. + + Explanation + =========== + + The attributes of this object are (see also the docstring of __init__): + + - f: The original (Expr) integrand. + - x: The variable of integration. + - T: List of variables in the extension. + - D: List of derivations in the extension; corresponds to the elements of T. + - fa: Poly of the numerator of the integrand. + - fd: Poly of the denominator of the integrand. + - Tfuncs: Lambda() representations of each element of T (except for x). + For back-substitution after integration. + - backsubs: A (possibly empty) list of further substitutions to be made on + the final integral to make it look more like the integrand. + - exts: + - extargs: + - cases: List of string representations of the cases of T. + - t: The top level extension variable, as defined by the current level + (see level below). + - d: The top level extension derivation, as defined by the current + derivation (see level below). + - case: The string representation of the case of self.d. + (Note that self.T and self.D will always contain the complete extension, + regardless of the level. Therefore, you should ALWAYS use DE.t and DE.d + instead of DE.T[-1] and DE.D[-1]. If you want to have a list of the + derivations or variables only up to the current level, use + DE.D[:len(DE.D) + DE.level + 1] and DE.T[:len(DE.T) + DE.level + 1]. Note + that, in particular, the derivation() function does this.) + + The following are also attributes, but will probably not be useful other + than in internal use: + - newf: Expr form of fa/fd. + - level: The number (between -1 and -len(self.T)) such that + self.T[self.level] == self.t and self.D[self.level] == self.d. + Use the methods self.increment_level() and self.decrement_level() to change + the current level. + """ + # __slots__ is defined mainly so we can iterate over all the attributes + # of the class easily (the memory use doesn't matter too much, since we + # only create one DifferentialExtension per integration). Also, it's nice + # to have a safeguard when debugging. + __slots__ = ('f', 'x', 'T', 'D', 'fa', 'fd', 'Tfuncs', 'backsubs', + 'exts', 'extargs', 'cases', 'case', 't', 'd', 'newf', 'level', + 'ts', 'dummy') + + def __init__(self, f=None, x=None, handle_first='log', dummy=False, extension=None, rewrite_complex=None): + """ + Tries to build a transcendental extension tower from ``f`` with respect to ``x``. + + Explanation + =========== + + If it is successful, creates a DifferentialExtension object with, among + others, the attributes fa, fd, D, T, Tfuncs, and backsubs such that + fa and fd are Polys in T[-1] with rational coefficients in T[:-1], + fa/fd == f, and D[i] is a Poly in T[i] with rational coefficients in + T[:i] representing the derivative of T[i] for each i from 1 to len(T). + Tfuncs is a list of Lambda objects for back replacing the functions + after integrating. Lambda() is only used (instead of lambda) to make + them easier to test and debug. Note that Tfuncs corresponds to the + elements of T, except for T[0] == x, but they should be back-substituted + in reverse order. backsubs is a (possibly empty) back-substitution list + that should be applied on the completed integral to make it look more + like the original integrand. + + If it is unsuccessful, it raises NotImplementedError. + + You can also create an object by manually setting the attributes as a + dictionary to the extension keyword argument. You must include at least + D. Warning, any attribute that is not given will be set to None. The + attributes T, t, d, cases, case, x, and level are set automatically and + do not need to be given. The functions in the Risch Algorithm will NOT + check to see if an attribute is None before using it. This also does not + check to see if the extension is valid (non-algebraic) or even if it is + self-consistent. Therefore, this should only be used for + testing/debugging purposes. + """ + # XXX: If you need to debug this function, set the break point here + + if extension: + if 'D' not in extension: + raise ValueError("At least the key D must be included with " + "the extension flag to DifferentialExtension.") + for attr in extension: + setattr(self, attr, extension[attr]) + + self._auto_attrs() + + return + elif f is None or x is None: + raise ValueError("Either both f and x or a manual extension must " + "be given.") + + if handle_first not in ('log', 'exp'): + raise ValueError("handle_first must be 'log' or 'exp', not %s." % + str(handle_first)) + + # f will be the original function, self.f might change if we reset + # (e.g., we pull out a constant from an exponential) + self.f = f + self.x = x + # setting the default value 'dummy' + self.dummy = dummy + self.reset() + exp_new_extension, log_new_extension = True, True + + # case of 'automatic' choosing + if rewrite_complex is None: + rewrite_complex = I in self.f.atoms() + + if rewrite_complex: + rewritables = { + (sin, cos, cot, tan, sinh, cosh, coth, tanh): exp, + (asin, acos, acot, atan): log, + } + # rewrite the trigonometric components + for candidates, rule in rewritables.items(): + self.newf = self.newf.rewrite(candidates, rule) + self.newf = cancel(self.newf) + else: + if any(i.has(x) for i in self.f.atoms(sin, cos, tan, atan, asin, acos)): + raise NotImplementedError("Trigonometric extensions are not " + "supported (yet!)") + + exps = set() + pows = set() + numpows = set() + sympows = set() + logs = set() + symlogs = set() + + while True: + if self.newf.is_rational_function(*self.T): + break + + if not exp_new_extension and not log_new_extension: + # We couldn't find a new extension on the last pass, so I guess + # we can't do it. + raise NotImplementedError("Couldn't find an elementary " + "transcendental extension for %s. Try using a " % str(f) + + "manual extension with the extension flag.") + + exps, pows, numpows, sympows, log_new_extension = \ + self._rewrite_exps_pows(exps, pows, numpows, sympows, log_new_extension) + + logs, symlogs = self._rewrite_logs(logs, symlogs) + + if handle_first == 'exp' or not log_new_extension: + exp_new_extension = self._exp_part(exps) + if exp_new_extension is None: + # reset and restart + self.f = self.newf + self.reset() + exp_new_extension = True + continue + + if handle_first == 'log' or not exp_new_extension: + log_new_extension = self._log_part(logs) + + self.fa, self.fd = frac_in(self.newf, self.t) + self._auto_attrs() + + return + + def __getattr__(self, attr): + # Avoid AttributeErrors when debugging + if attr not in self.__slots__: + raise AttributeError("%s has no attribute %s" % (repr(self), repr(attr))) + return None + + def _rewrite_exps_pows(self, exps, pows, numpows, + sympows, log_new_extension): + """ + Rewrite exps/pows for better processing. + """ + from .prde import is_deriv_k + + # Pre-preparsing. + ################# + # Get all exp arguments, so we can avoid ahead of time doing + # something like t1 = exp(x), t2 = exp(x/2) == sqrt(t1). + + # Things like sqrt(exp(x)) do not automatically simplify to + # exp(x/2), so they will be viewed as algebraic. The easiest way + # to handle this is to convert all instances of exp(a)**Rational + # to exp(Rational*a) before doing anything else. Note that the + # _exp_part code can generate terms of this form, so we do need to + # do this at each pass (or else modify it to not do that). + + ratpows = [i for i in self.newf.atoms(Pow) + if (isinstance(i.base, exp) and i.exp.is_Rational)] + + ratpows_repl = [ + (i, i.base.base**(i.exp*i.base.exp)) for i in ratpows] + self.backsubs += [(j, i) for i, j in ratpows_repl] + self.newf = self.newf.xreplace(dict(ratpows_repl)) + + # To make the process deterministic, the args are sorted + # so that functions with smaller op-counts are processed first. + # Ties are broken with the default_sort_key. + + # XXX Although the method is deterministic no additional work + # has been done to guarantee that the simplest solution is + # returned and that it would be affected be using different + # variables. Though it is possible that this is the case + # one should know that it has not been done intentionally, so + # further improvements may be possible. + + # TODO: This probably doesn't need to be completely recomputed at + # each pass. + exps = update_sets(exps, self.newf.atoms(exp), + lambda i: i.exp.is_rational_function(*self.T) and + i.exp.has(*self.T)) + pows = update_sets(pows, self.newf.atoms(Pow), + lambda i: i.exp.is_rational_function(*self.T) and + i.exp.has(*self.T)) + numpows = update_sets(numpows, set(pows), + lambda i: not i.base.has(*self.T)) + sympows = update_sets(sympows, set(pows) - set(numpows), + lambda i: i.base.is_rational_function(*self.T) and + not i.exp.is_Integer) + + # The easiest way to deal with non-base E powers is to convert them + # into base E, integrate, and then convert back. + for i in ordered(pows): + old = i + new = exp(i.exp*log(i.base)) + # If exp is ever changed to automatically reduce exp(x*log(2)) + # to 2**x, then this will break. The solution is to not change + # exp to do that :) + if i in sympows: + if i.exp.is_Rational: + raise NotImplementedError("Algebraic extensions are " + "not supported (%s)." % str(i)) + # We can add a**b only if log(a) in the extension, because + # a**b == exp(b*log(a)). + basea, based = frac_in(i.base, self.t) + A = is_deriv_k(basea, based, self) + if A is None: + # Nonelementary monomial (so far) + + # TODO: Would there ever be any benefit from just + # adding log(base) as a new monomial? + # ANSWER: Yes, otherwise we can't integrate x**x (or + # rather prove that it has no elementary integral) + # without first manually rewriting it as exp(x*log(x)) + self.newf = self.newf.xreplace({old: new}) + self.backsubs += [(new, old)] + log_new_extension = self._log_part([log(i.base)]) + exps = update_sets(exps, self.newf.atoms(exp), lambda i: + i.exp.is_rational_function(*self.T) and i.exp.has(*self.T)) + continue + ans, u, const = A + newterm = exp(i.exp*(log(const) + u)) + # Under the current implementation, exp kills terms + # only if they are of the form a*log(x), where a is a + # Number. This case should have already been killed by the + # above tests. Again, if this changes to kill more than + # that, this will break, which maybe is a sign that you + # shouldn't be changing that. Actually, if anything, this + # auto-simplification should be removed. See + # https://groups.google.com/group/sympy/browse_thread/thread/a61d48235f16867f + + self.newf = self.newf.xreplace({i: newterm}) + + elif i not in numpows: + continue + else: + # i in numpows + newterm = new + # TODO: Just put it in self.Tfuncs + self.backsubs.append((new, old)) + self.newf = self.newf.xreplace({old: newterm}) + exps.append(newterm) + + return exps, pows, numpows, sympows, log_new_extension + + def _rewrite_logs(self, logs, symlogs): + """ + Rewrite logs for better processing. + """ + atoms = self.newf.atoms(log) + logs = update_sets(logs, atoms, + lambda i: i.args[0].is_rational_function(*self.T) and + i.args[0].has(*self.T)) + symlogs = update_sets(symlogs, atoms, + lambda i: i.has(*self.T) and i.args[0].is_Pow and + i.args[0].base.is_rational_function(*self.T) and + not i.args[0].exp.is_Integer) + + # We can handle things like log(x**y) by converting it to y*log(x) + # This will fix not only symbolic exponents of the argument, but any + # non-Integer exponent, like log(sqrt(x)). The exponent can also + # depend on x, like log(x**x). + for i in ordered(symlogs): + # Unlike in the exponential case above, we do not ever + # potentially add new monomials (above we had to add log(a)). + # Therefore, there is no need to run any is_deriv functions + # here. Just convert log(a**b) to b*log(a) and let + # log_new_extension() handle it from there. + lbase = log(i.args[0].base) + logs.append(lbase) + new = i.args[0].exp*lbase + self.newf = self.newf.xreplace({i: new}) + self.backsubs.append((new, i)) + + # remove any duplicates + logs = sorted(set(logs), key=default_sort_key) + + return logs, symlogs + + def _auto_attrs(self): + """ + Set attributes that are generated automatically. + """ + if not self.T: + # i.e., when using the extension flag and T isn't given + self.T = [i.gen for i in self.D] + if not self.x: + self.x = self.T[0] + self.cases = [get_case(d, t) for d, t in zip(self.D, self.T)] + self.level = -1 + self.t = self.T[self.level] + self.d = self.D[self.level] + self.case = self.cases[self.level] + + def _exp_part(self, exps): + """ + Try to build an exponential extension. + + Returns + ======= + + Returns True if there was a new extension, False if there was no new + extension but it was able to rewrite the given exponentials in terms + of the existing extension, and None if the entire extension building + process should be restarted. If the process fails because there is no + way around an algebraic extension (e.g., exp(log(x)/2)), it will raise + NotImplementedError. + """ + from .prde import is_log_deriv_k_t_radical + new_extension = False + restart = False + expargs = [i.exp for i in exps] + ip = integer_powers(expargs) + for arg, others in ip: + # Minimize potential problems with algebraic substitution + others.sort(key=lambda i: i[1]) + + arga, argd = frac_in(arg, self.t) + A = is_log_deriv_k_t_radical(arga, argd, self) + + if A is not None: + ans, u, n, const = A + # if n is 1 or -1, it's algebraic, but we can handle it + if n == -1: + # This probably will never happen, because + # Rational.as_numer_denom() returns the negative term in + # the numerator. But in case that changes, reduce it to + # n == 1. + n = 1 + u **= -1 + const *= -1 + ans = [(i, -j) for i, j in ans] + + if n == 1: + # Example: exp(x + x**2) over QQ(x, exp(x), exp(x**2)) + self.newf = self.newf.xreplace({exp(arg): exp(const)*Mul(*[ + u**power for u, power in ans])}) + self.newf = self.newf.xreplace({exp(p*exparg): + exp(const*p) * Mul(*[u**power for u, power in ans]) + for exparg, p in others}) + # TODO: Add something to backsubs to put exp(const*p) + # back together. + + continue + + else: + # Bad news: we have an algebraic radical. But maybe we + # could still avoid it by choosing a different extension. + # For example, integer_powers() won't handle exp(x/2 + 1) + # over QQ(x, exp(x)), but if we pull out the exp(1), it + # will. Or maybe we have exp(x + x**2/2), over + # QQ(x, exp(x), exp(x**2)), which is exp(x)*sqrt(exp(x**2)), + # but if we use QQ(x, exp(x), exp(x**2/2)), then they will + # all work. + # + # So here is what we do: If there is a non-zero const, pull + # it out and retry. Also, if len(ans) > 1, then rewrite + # exp(arg) as the product of exponentials from ans, and + # retry that. If const == 0 and len(ans) == 1, then we + # assume that it would have been handled by either + # integer_powers() or n == 1 above if it could be handled, + # so we give up at that point. For example, you can never + # handle exp(log(x)/2) because it equals sqrt(x). + + if const or len(ans) > 1: + rad = Mul(*[term**(power/n) for term, power in ans]) + self.newf = self.newf.xreplace({exp(p*exparg): + exp(const*p)*rad for exparg, p in others}) + self.newf = self.newf.xreplace(dict(list(zip(reversed(self.T), + reversed([f(self.x) for f in self.Tfuncs]))))) + restart = True + break + else: + # TODO: give algebraic dependence in error string + raise NotImplementedError("Cannot integrate over " + "algebraic extensions.") + + else: + arga, argd = frac_in(arg, self.t) + darga = (argd*derivation(Poly(arga, self.t), self) - + arga*derivation(Poly(argd, self.t), self)) + dargd = argd**2 + darga, dargd = darga.cancel(dargd, include=True) + darg = darga.as_expr()/dargd.as_expr() + self.t = next(self.ts) + self.T.append(self.t) + self.extargs.append(arg) + self.exts.append('exp') + self.D.append(darg.as_poly(self.t, expand=False)*Poly(self.t, + self.t, expand=False)) + if self.dummy: + i = Dummy("i") + else: + i = Symbol('i') + self.Tfuncs += [Lambda(i, exp(arg.subs(self.x, i)))] + self.newf = self.newf.xreplace( + {exp(exparg): self.t**p for exparg, p in others}) + new_extension = True + + if restart: + return None + return new_extension + + def _log_part(self, logs): + """ + Try to build a logarithmic extension. + + Returns + ======= + + Returns True if there was a new extension and False if there was no new + extension but it was able to rewrite the given logarithms in terms + of the existing extension. Unlike with exponential extensions, there + is no way that a logarithm is not transcendental over and cannot be + rewritten in terms of an already existing extension in a non-algebraic + way, so this function does not ever return None or raise + NotImplementedError. + """ + from .prde import is_deriv_k + new_extension = False + logargs = [i.args[0] for i in logs] + for arg in ordered(logargs): + # The log case is easier, because whenever a logarithm is algebraic + # over the base field, it is of the form a1*t1 + ... an*tn + c, + # which is a polynomial, so we can just replace it with that. + # In other words, we don't have to worry about radicals. + arga, argd = frac_in(arg, self.t) + A = is_deriv_k(arga, argd, self) + if A is not None: + ans, u, const = A + newterm = log(const) + u + self.newf = self.newf.xreplace({log(arg): newterm}) + continue + + else: + arga, argd = frac_in(arg, self.t) + darga = (argd*derivation(Poly(arga, self.t), self) - + arga*derivation(Poly(argd, self.t), self)) + dargd = argd**2 + darg = darga.as_expr()/dargd.as_expr() + self.t = next(self.ts) + self.T.append(self.t) + self.extargs.append(arg) + self.exts.append('log') + self.D.append(cancel(darg.as_expr()/arg).as_poly(self.t, + expand=False)) + if self.dummy: + i = Dummy("i") + else: + i = Symbol('i') + self.Tfuncs += [Lambda(i, log(arg.subs(self.x, i)))] + self.newf = self.newf.xreplace({log(arg): self.t}) + new_extension = True + + return new_extension + + @property + def _important_attrs(self): + """ + Returns some of the more important attributes of self. + + Explanation + =========== + + Used for testing and debugging purposes. + + The attributes are (fa, fd, D, T, Tfuncs, backsubs, + exts, extargs). + """ + return (self.fa, self.fd, self.D, self.T, self.Tfuncs, + self.backsubs, self.exts, self.extargs) + + # NOTE: this printing doesn't follow the Python's standard + # eval(repr(DE)) == DE, where DE is the DifferentialExtension object, + # also this printing is supposed to contain all the important + # attributes of a DifferentialExtension object + def __repr__(self): + # no need to have GeneratorType object printed in it + r = [(attr, getattr(self, attr)) for attr in self.__slots__ + if not isinstance(getattr(self, attr), GeneratorType)] + return self.__class__.__name__ + '(dict(%r))' % (r) + + # fancy printing of DifferentialExtension object + def __str__(self): + return (self.__class__.__name__ + '({fa=%s, fd=%s, D=%s})' % + (self.fa, self.fd, self.D)) + + # should only be used for debugging purposes, internally + # f1 = f2 = log(x) at different places in code execution + # may return D1 != D2 as True, since 'level' or other attribute + # may differ + def __eq__(self, other): + for attr in self.__class__.__slots__: + d1, d2 = getattr(self, attr), getattr(other, attr) + if not (isinstance(d1, GeneratorType) or d1 == d2): + return False + return True + + def reset(self): + """ + Reset self to an initial state. Used by __init__. + """ + self.t = self.x + self.T = [self.x] + self.D = [Poly(1, self.x)] + self.level = -1 + self.exts = [None] + self.extargs = [None] + if self.dummy: + self.ts = numbered_symbols('t', cls=Dummy) + else: + # For testing + self.ts = numbered_symbols('t') + # For various things that we change to make things work that we need to + # change back when we are done. + self.backsubs = [] + self.Tfuncs = [] + self.newf = self.f + + def indices(self, extension): + """ + Parameters + ========== + + extension : str + Represents a valid extension type. + + Returns + ======= + + list: A list of indices of 'exts' where extension of + type 'extension' is present. + + Examples + ======== + + >>> from sympy.integrals.risch import DifferentialExtension + >>> from sympy import log, exp + >>> from sympy.abc import x + >>> DE = DifferentialExtension(log(x) + exp(x), x, handle_first='exp') + >>> DE.indices('log') + [2] + >>> DE.indices('exp') + [1] + + """ + return [i for i, ext in enumerate(self.exts) if ext == extension] + + def increment_level(self): + """ + Increment the level of self. + + Explanation + =========== + + This makes the working differential extension larger. self.level is + given relative to the end of the list (-1, -2, etc.), so we do not need + do worry about it when building the extension. + """ + if self.level >= -1: + raise ValueError("The level of the differential extension cannot " + "be incremented any further.") + + self.level += 1 + self.t = self.T[self.level] + self.d = self.D[self.level] + self.case = self.cases[self.level] + return None + + def decrement_level(self): + """ + Decrease the level of self. + + Explanation + =========== + + This makes the working differential extension smaller. self.level is + given relative to the end of the list (-1, -2, etc.), so we do not need + do worry about it when building the extension. + """ + if self.level <= -len(self.T): + raise ValueError("The level of the differential extension cannot " + "be decremented any further.") + + self.level -= 1 + self.t = self.T[self.level] + self.d = self.D[self.level] + self.case = self.cases[self.level] + return None + + +def update_sets(seq, atoms, func): + s = set(seq) + s = atoms.intersection(s) + new = atoms - s + s.update(list(filter(func, new))) + return list(s) + + +class DecrementLevel: + """ + A context manager for decrementing the level of a DifferentialExtension. + """ + __slots__ = ('DE',) + + def __init__(self, DE): + self.DE = DE + return + + def __enter__(self): + self.DE.decrement_level() + + def __exit__(self, exc_type, exc_value, traceback): + self.DE.increment_level() + + +class NonElementaryIntegralException(Exception): + """ + Exception used by subroutines within the Risch algorithm to indicate to one + another that the function being integrated does not have an elementary + integral in the given differential field. + """ + # TODO: Rewrite algorithms below to use this (?) + + # TODO: Pass through information about why the integral was nonelementary, + # and store that in the resulting NonElementaryIntegral somehow. + pass + + +def gcdex_diophantine(a, b, c): + """ + Extended Euclidean Algorithm, Diophantine version. + + Explanation + =========== + + Given ``a``, ``b`` in K[x] and ``c`` in (a, b), the ideal generated by ``a`` and + ``b``, return (s, t) such that s*a + t*b == c and either s == 0 or s.degree() + < b.degree(). + """ + # Extended Euclidean Algorithm (Diophantine Version) pg. 13 + # TODO: This should go in densetools.py. + # XXX: Better name? + + s, g = a.half_gcdex(b) + s *= c.exquo(g) # Inexact division means c is not in (a, b) + if s and s.degree() >= b.degree(): + _, s = s.div(b) + t = (c - s*a).exquo(b) + return (s, t) + + +def frac_in(f, t, *, cancel=False, **kwargs): + """ + Returns the tuple (fa, fd), where fa and fd are Polys in t. + + Explanation + =========== + + This is a common idiom in the Risch Algorithm functions, so we abstract + it out here. ``f`` should be a basic expression, a Poly, or a tuple (fa, fd), + where fa and fd are either basic expressions or Polys, and f == fa/fd. + **kwargs are applied to Poly. + """ + if isinstance(f, tuple): + fa, fd = f + f = fa.as_expr()/fd.as_expr() + fa, fd = f.as_expr().as_numer_denom() + fa, fd = fa.as_poly(t, **kwargs), fd.as_poly(t, **kwargs) + if cancel: + fa, fd = fa.cancel(fd, include=True) + if fa is None or fd is None: + raise ValueError("Could not turn %s into a fraction in %s." % (f, t)) + return (fa, fd) + + +def as_poly_1t(p, t, z): + """ + (Hackish) way to convert an element ``p`` of K[t, 1/t] to K[t, z]. + + In other words, ``z == 1/t`` will be a dummy variable that Poly can handle + better. + + See issue 5131. + + Examples + ======== + + >>> from sympy import random_poly + >>> from sympy.integrals.risch import as_poly_1t + >>> from sympy.abc import x, z + + >>> p1 = random_poly(x, 10, -10, 10) + >>> p2 = random_poly(x, 10, -10, 10) + >>> p = p1 + p2.subs(x, 1/x) + >>> as_poly_1t(p, x, z).as_expr().subs(z, 1/x) == p + True + """ + # TODO: Use this on the final result. That way, we can avoid answers like + # (...)*exp(-x). + pa, pd = frac_in(p, t, cancel=True) + if not pd.is_monomial: + # XXX: Is there a better Poly exception that we could raise here? + # Either way, if you see this (from the Risch Algorithm) it indicates + # a bug. + raise PolynomialError("%s is not an element of K[%s, 1/%s]." % (p, t, t)) + + t_part, remainder = pa.div(pd) + + ans = t_part.as_poly(t, z, expand=False) + + if remainder: + one = remainder.one + tp = t*one + r = pd.degree() - remainder.degree() + z_part = remainder.transform(one, tp) * tp**r + z_part = z_part.replace(t, z).to_field().quo_ground(pd.LC()) + ans += z_part.as_poly(t, z, expand=False) + + return ans + + +def derivation(p, DE, coefficientD=False, basic=False): + """ + Computes Dp. + + Explanation + =========== + + Given the derivation D with D = d/dx and p is a polynomial in t over + K(x), return Dp. + + If coefficientD is True, it computes the derivation kD + (kappaD), which is defined as kD(sum(ai*Xi**i, (i, 0, n))) == + sum(Dai*Xi**i, (i, 1, n)) (Definition 3.2.2, page 80). X in this case is + T[-1], so coefficientD computes the derivative just with respect to T[:-1], + with T[-1] treated as a constant. + + If ``basic=True``, the returns a Basic expression. Elements of D can still be + instances of Poly. + """ + if basic: + r = 0 + else: + r = Poly(0, DE.t) + + t = DE.t + if coefficientD: + if DE.level <= -len(DE.T): + # 'base' case, the answer is 0. + return r + DE.decrement_level() + + D = DE.D[:len(DE.D) + DE.level + 1] + T = DE.T[:len(DE.T) + DE.level + 1] + + for d, v in zip(D, T): + pv = p.as_poly(v) + if pv is None or basic: + pv = p.as_expr() + + if basic: + r += d.as_expr()*pv.diff(v) + else: + r += (d.as_expr()*pv.diff(v).as_expr()).as_poly(t) + + if basic: + r = cancel(r) + if coefficientD: + DE.increment_level() + + return r + + +def get_case(d, t): + """ + Returns the type of the derivation d. + + Returns one of {'exp', 'tan', 'base', 'primitive', 'other_linear', + 'other_nonlinear'}. + """ + if not d.expr.has(t): + if d.is_one: + return 'base' + return 'primitive' + if d.rem(Poly(t, t)).is_zero: + return 'exp' + if d.rem(Poly(1 + t**2, t)).is_zero: + return 'tan' + if d.degree(t) > 1: + return 'other_nonlinear' + return 'other_linear' + + +def splitfactor(p, DE, coefficientD=False, z=None): + """ + Splitting factorization. + + Explanation + =========== + + Given a derivation D on k[t] and ``p`` in k[t], return (p_n, p_s) in + k[t] x k[t] such that p = p_n*p_s, p_s is special, and each square + factor of p_n is normal. + + Page. 100 + """ + kinv = [1/x for x in DE.T[:DE.level]] + if z: + kinv.append(z) + + One = Poly(1, DE.t, domain=p.get_domain()) + Dp = derivation(p, DE, coefficientD=coefficientD) + # XXX: Is this right? + if p.is_zero: + return (p, One) + + if not p.expr.has(DE.t): + s = p.as_poly(*kinv).gcd(Dp.as_poly(*kinv)).as_poly(DE.t) + n = p.exquo(s) + return (n, s) + + if not Dp.is_zero: + h = p.gcd(Dp).to_field() + g = p.gcd(p.diff(DE.t)).to_field() + s = h.exquo(g) + + if s.degree(DE.t) == 0: + return (p, One) + + q_split = splitfactor(p.exquo(s), DE, coefficientD=coefficientD) + + return (q_split[0], q_split[1]*s) + else: + return (p, One) + + +def splitfactor_sqf(p, DE, coefficientD=False, z=None, basic=False): + """ + Splitting Square-free Factorization. + + Explanation + =========== + + Given a derivation D on k[t] and ``p`` in k[t], returns (N1, ..., Nm) + and (S1, ..., Sm) in k[t]^m such that p = + (N1*N2**2*...*Nm**m)*(S1*S2**2*...*Sm**m) is a splitting + factorization of ``p`` and the Ni and Si are square-free and coprime. + """ + # TODO: This algorithm appears to be faster in every case + # TODO: Verify this and splitfactor() for multiple extensions + kkinv = [1/x for x in DE.T[:DE.level]] + DE.T[:DE.level] + if z: + kkinv = [z] + + S = [] + N = [] + p_sqf = p.sqf_list_include() + if p.is_zero: + return (((p, 1),), ()) + + for pi, i in p_sqf: + Si = pi.as_poly(*kkinv).gcd(derivation(pi, DE, + coefficientD=coefficientD,basic=basic).as_poly(*kkinv)).as_poly(DE.t) + pi = Poly(pi, DE.t) + Si = Poly(Si, DE.t) + Ni = pi.exquo(Si) + if not Si.is_one: + S.append((Si, i)) + if not Ni.is_one: + N.append((Ni, i)) + + return (tuple(N), tuple(S)) + + +def canonical_representation(a, d, DE): + """ + Canonical Representation. + + Explanation + =========== + + Given a derivation D on k[t] and f = a/d in k(t), return (f_p, f_s, + f_n) in k[t] x k(t) x k(t) such that f = f_p + f_s + f_n is the + canonical representation of f (f_p is a polynomial, f_s is reduced + (has a special denominator), and f_n is simple (has a normal + denominator). + """ + # Make d monic + l = Poly(1/d.LC(), DE.t) + a, d = a.mul(l), d.mul(l) + + q, r = a.div(d) + dn, ds = splitfactor(d, DE) + + b, c = gcdex_diophantine(dn.as_poly(DE.t), ds.as_poly(DE.t), r.as_poly(DE.t)) + b, c = b.as_poly(DE.t), c.as_poly(DE.t) + + return (q, (b, ds), (c, dn)) + + +def hermite_reduce(a, d, DE): + """ + Hermite Reduction - Mack's Linear Version. + + Given a derivation D on k(t) and f = a/d in k(t), returns g, h, r in + k(t) such that f = Dg + h + r, h is simple, and r is reduced. + + """ + # Make d monic + l = Poly(1/d.LC(), DE.t) + a, d = a.mul(l), d.mul(l) + + fp, fs, fn = canonical_representation(a, d, DE) + a, d = fn + l = Poly(1/d.LC(), DE.t) + a, d = a.mul(l), d.mul(l) + + ga = Poly(0, DE.t) + gd = Poly(1, DE.t) + + dd = derivation(d, DE) + dm = gcd(d.to_field(), dd.to_field()).as_poly(DE.t) + ds, _ = d.div(dm) + + while dm.degree(DE.t) > 0: + + ddm = derivation(dm, DE) + dm2 = gcd(dm.to_field(), ddm.to_field()) + dms, _ = dm.div(dm2) + ds_ddm = ds.mul(ddm) + ds_ddm_dm, _ = ds_ddm.div(dm) + + b, c = gcdex_diophantine(-ds_ddm_dm.as_poly(DE.t), + dms.as_poly(DE.t), a.as_poly(DE.t)) + b, c = b.as_poly(DE.t), c.as_poly(DE.t) + + db = derivation(b, DE).as_poly(DE.t) + ds_dms, _ = ds.div(dms) + a = c.as_poly(DE.t) - db.mul(ds_dms).as_poly(DE.t) + + ga = ga*dm + b*gd + gd = gd*dm + ga, gd = ga.cancel(gd, include=True) + dm = dm2 + + q, r = a.div(ds) + ga, gd = ga.cancel(gd, include=True) + + r, d = r.cancel(ds, include=True) + rra = q*fs[1] + fp*fs[1] + fs[0] + rrd = fs[1] + rra, rrd = rra.cancel(rrd, include=True) + + return ((ga, gd), (r, d), (rra, rrd)) + + +def polynomial_reduce(p, DE): + """ + Polynomial Reduction. + + Explanation + =========== + + Given a derivation D on k(t) and p in k[t] where t is a nonlinear + monomial over k, return q, r in k[t] such that p = Dq + r, and + deg(r) < deg_t(Dt). + """ + q = Poly(0, DE.t) + while p.degree(DE.t) >= DE.d.degree(DE.t): + m = p.degree(DE.t) - DE.d.degree(DE.t) + 1 + q0 = Poly(DE.t**m, DE.t).mul(Poly(p.as_poly(DE.t).LC()/ + (m*DE.d.LC()), DE.t)) + q += q0 + p = p - derivation(q0, DE) + + return (q, p) + + +def laurent_series(a, d, F, n, DE): + """ + Contribution of ``F`` to the full partial fraction decomposition of A/D. + + Explanation + =========== + + Given a field K of characteristic 0 and ``A``,``D``,``F`` in K[x] with D monic, + nonzero, coprime with A, and ``F`` the factor of multiplicity n in the square- + free factorization of D, return the principal parts of the Laurent series of + A/D at all the zeros of ``F``. + """ + if F.degree()==0: + return 0 + Z = _symbols('z', n) + z = Symbol('z') + Z.insert(0, z) + delta_a = Poly(0, DE.t) + delta_d = Poly(1, DE.t) + + E = d.quo(F**n) + ha, hd = (a, E*Poly(z**n, DE.t)) + dF = derivation(F,DE) + B, _ = gcdex_diophantine(E, F, Poly(1,DE.t)) + C, _ = gcdex_diophantine(dF, F, Poly(1,DE.t)) + + # initialization + F_store = F + V, DE_D_list, H_list= [], [], [] + + for j in range(0, n): + # jth derivative of z would be substituted with dfnth/(j+1) where dfnth =(d^n)f/(dx)^n + F_store = derivation(F_store, DE) + v = (F_store.as_expr())/(j + 1) + V.append(v) + DE_D_list.append(Poly(Z[j + 1],Z[j])) + + DE_new = DifferentialExtension(extension = {'D': DE_D_list}) #a differential indeterminate + for j in range(0, n): + zEha = Poly(z**(n + j), DE.t)*E**(j + 1)*ha + zEhd = hd + Pa, Pd = cancel((zEha, zEhd))[1], cancel((zEha, zEhd))[2] + Q = Pa.quo(Pd) + for i in range(0, j + 1): + Q = Q.subs(Z[i], V[i]) + Dha = (hd*derivation(ha, DE, basic=True).as_poly(DE.t) + + ha*derivation(hd, DE, basic=True).as_poly(DE.t) + + hd*derivation(ha, DE_new, basic=True).as_poly(DE.t) + + ha*derivation(hd, DE_new, basic=True).as_poly(DE.t)) + Dhd = Poly(j + 1, DE.t)*hd**2 + ha, hd = Dha, Dhd + + Ff, _ = F.div(gcd(F, Q)) + F_stara, F_stard = frac_in(Ff, DE.t) + if F_stara.degree(DE.t) - F_stard.degree(DE.t) > 0: + QBC = Poly(Q, DE.t)*B**(1 + j)*C**(n + j) + H = QBC + H_list.append(H) + H = (QBC*F_stard).rem(F_stara) + alphas = real_roots(F_stara) + for alpha in list(alphas): + delta_a = delta_a*Poly((DE.t - alpha)**(n - j), DE.t) + Poly(H.eval(alpha), DE.t) + delta_d = delta_d*Poly((DE.t - alpha)**(n - j), DE.t) + return (delta_a, delta_d, H_list) + + +def recognize_derivative(a, d, DE, z=None): + """ + Compute the squarefree factorization of the denominator of f + and for each Di the polynomial H in K[x] (see Theorem 2.7.1), using the + LaurentSeries algorithm. Write Di = GiEi where Gj = gcd(Hn, Di) and + gcd(Ei,Hn) = 1. Since the residues of f at the roots of Gj are all 0, and + the residue of f at a root alpha of Ei is Hi(a) != 0, f is the derivative of a + rational function if and only if Ei = 1 for each i, which is equivalent to + Di | H[-1] for each i. + """ + flag =True + a, d = a.cancel(d, include=True) + _, r = a.div(d) + Np, Sp = splitfactor_sqf(d, DE, coefficientD=True, z=z) + + j = 1 + for s, _ in Sp: + delta_a, delta_d, H = laurent_series(r, d, s, j, DE) + g = gcd(d, H[-1]).as_poly() + if g is not d: + flag = False + break + j = j + 1 + return flag + + +def recognize_log_derivative(a, d, DE, z=None): + """ + There exists a v in K(x)* such that f = dv/v + where f a rational function if and only if f can be written as f = A/D + where D is squarefree,deg(A) < deg(D), gcd(A, D) = 1, + and all the roots of the Rothstein-Trager resultant are integers. In that case, + any of the Rothstein-Trager, Lazard-Rioboo-Trager or Czichowski algorithm + produces u in K(x) such that du/dx = uf. + """ + + z = z or Dummy('z') + a, d = a.cancel(d, include=True) + _, a = a.div(d) + + pz = Poly(z, DE.t) + Dd = derivation(d, DE) + q = a - pz*Dd + r, _ = d.resultant(q, includePRS=True) + r = Poly(r, z) + Np, Sp = splitfactor_sqf(r, DE, coefficientD=True, z=z) + + for s, _ in Sp: + # TODO also consider the complex roots which should + # turn the flag false + a = real_roots(s.as_poly(z)) + + if not all(j.is_Integer for j in a): + return False + return True + +def residue_reduce(a, d, DE, z=None, invert=True): + """ + Lazard-Rioboo-Rothstein-Trager resultant reduction. + + Explanation + =========== + + Given a derivation ``D`` on k(t) and f in k(t) simple, return g + elementary over k(t) and a Boolean b in {True, False} such that f - + Dg in k[t] if b == True or f + h and f + h - Dg do not have an + elementary integral over k(t) for any h in k (reduced) if b == + False. + + Returns (G, b), where G is a tuple of tuples of the form (s_i, S_i), + such that g = Add(*[RootSum(s_i, lambda z: z*log(S_i(z, t))) for + S_i, s_i in G]). f - Dg is the remaining integral, which is elementary + only if b == True, and hence the integral of f is elementary only if + b == True. + + f - Dg is not calculated in this function because that would require + explicitly calculating the RootSum. Use residue_reduce_derivation(). + """ + # TODO: Use log_to_atan() from rationaltools.py + # If r = residue_reduce(...), then the logarithmic part is given by: + # sum([RootSum(a[0].as_poly(z), lambda i: i*log(a[1].as_expr()).subs(z, + # i)).subs(t, log(x)) for a in r[0]]) + + z = z or Dummy('z') + a, d = a.cancel(d, include=True) + a, d = a.to_field().mul_ground(1/d.LC()), d.to_field().mul_ground(1/d.LC()) + kkinv = [1/x for x in DE.T[:DE.level]] + DE.T[:DE.level] + + if a.is_zero: + return ([], True) + _, a = a.div(d) + + pz = Poly(z, DE.t) + + Dd = derivation(d, DE) + q = a - pz*Dd + + if Dd.degree(DE.t) <= d.degree(DE.t): + r, R = d.resultant(q, includePRS=True) + else: + r, R = q.resultant(d, includePRS=True) + + R_map, H = {}, [] + for i in R: + R_map[i.degree()] = i + + r = Poly(r, z) + Np, Sp = splitfactor_sqf(r, DE, coefficientD=True, z=z) + + for s, i in Sp: + if i == d.degree(DE.t): + s = Poly(s, z).monic() + H.append((s, d)) + else: + h = R_map.get(i) + if h is None: + continue + h_lc = Poly(h.as_poly(DE.t).LC(), DE.t, field=True) + + h_lc_sqf = h_lc.sqf_list_include(all=True) + + for a, j in h_lc_sqf: + h = Poly(h, DE.t, field=True).exquo(Poly(gcd(a, s**j, *kkinv), + DE.t)) + + s = Poly(s, z).monic() + + if invert: + h_lc = Poly(h.as_poly(DE.t).LC(), DE.t, field=True, expand=False) + inv, coeffs = h_lc.as_poly(z, field=True).invert(s), [S.One] + + for coeff in h.coeffs()[1:]: + L = reduced(inv*coeff.as_poly(inv.gens), [s])[1] + coeffs.append(L.as_expr()) + + h = Poly(dict(list(zip(h.monoms(), coeffs))), DE.t) + + H.append((s, h)) + + b = not any(cancel(i.as_expr()).has(DE.t, z) for i, _ in Np) + + return (H, b) + + +def residue_reduce_to_basic(H, DE, z): + """ + Converts the tuple returned by residue_reduce() into a Basic expression. + """ + # TODO: check what Lambda does with RootOf + i = Dummy('i') + s = list(zip(reversed(DE.T), reversed([f(DE.x) for f in DE.Tfuncs]))) + + return sum(RootSum(a[0].as_poly(z), Lambda(i, i*log(a[1].as_expr()).subs( + {z: i}).subs(s))) for a in H) + + +def residue_reduce_derivation(H, DE, z): + """ + Computes the derivation of an expression returned by residue_reduce(). + + In general, this is a rational function in t, so this returns an + as_expr() result. + """ + # TODO: verify that this is correct for multiple extensions + i = Dummy('i') + return S(sum(RootSum(a[0].as_poly(z), Lambda(i, i*derivation(a[1], + DE).as_expr().subs(z, i)/a[1].as_expr().subs(z, i))) for a in H)) + + +def integrate_primitive_polynomial(p, DE): + """ + Integration of primitive polynomials. + + Explanation + =========== + + Given a primitive monomial t over k, and ``p`` in k[t], return q in k[t], + r in k, and a bool b in {True, False} such that r = p - Dq is in k if b is + True, or r = p - Dq does not have an elementary integral over k(t) if b is + False. + """ + Zero = Poly(0, DE.t) + q = Poly(0, DE.t) + + if not p.expr.has(DE.t): + return (Zero, p, True) + + from .prde import limited_integrate + while True: + if not p.expr.has(DE.t): + return (q, p, True) + + Dta, Dtb = frac_in(DE.d, DE.T[DE.level - 1]) + + with DecrementLevel(DE): # We had better be integrating the lowest extension (x) + # with ratint(). + a = p.LC() + aa, ad = frac_in(a, DE.t) + + try: + rv = limited_integrate(aa, ad, [(Dta, Dtb)], DE) + if rv is None: + raise NonElementaryIntegralException + (ba, bd), c = rv + except NonElementaryIntegralException: + return (q, p, False) + + m = p.degree(DE.t) + q0 = c[0].as_poly(DE.t)*Poly(DE.t**(m + 1)/(m + 1), DE.t) + \ + (ba.as_expr()/bd.as_expr()).as_poly(DE.t)*Poly(DE.t**m, DE.t) + + p = p - derivation(q0, DE) + q = q + q0 + + +def integrate_primitive(a, d, DE, z=None): + """ + Integration of primitive functions. + + Explanation + =========== + + Given a primitive monomial t over k and f in k(t), return g elementary over + k(t), i in k(t), and b in {True, False} such that i = f - Dg is in k if b + is True or i = f - Dg does not have an elementary integral over k(t) if b + is False. + + This function returns a Basic expression for the first argument. If b is + True, the second argument is Basic expression in k to recursively integrate. + If b is False, the second argument is an unevaluated Integral, which has + been proven to be nonelementary. + """ + # XXX: a and d must be canceled, or this might return incorrect results + z = z or Dummy("z") + s = list(zip(reversed(DE.T), reversed([f(DE.x) for f in DE.Tfuncs]))) + + g1, h, r = hermite_reduce(a, d, DE) + g2, b = residue_reduce(h[0], h[1], DE, z=z) + if not b: + i = cancel(a.as_expr()/d.as_expr() - (g1[1]*derivation(g1[0], DE) - + g1[0]*derivation(g1[1], DE)).as_expr()/(g1[1]**2).as_expr() - + residue_reduce_derivation(g2, DE, z)) + i = NonElementaryIntegral(cancel(i).subs(s), DE.x) + return ((g1[0].as_expr()/g1[1].as_expr()).subs(s) + + residue_reduce_to_basic(g2, DE, z), i, b) + + # h - Dg2 + r + p = cancel(h[0].as_expr()/h[1].as_expr() - residue_reduce_derivation(g2, + DE, z) + r[0].as_expr()/r[1].as_expr()) + p = p.as_poly(DE.t) + + q, i, b = integrate_primitive_polynomial(p, DE) + + ret = ((g1[0].as_expr()/g1[1].as_expr() + q.as_expr()).subs(s) + + residue_reduce_to_basic(g2, DE, z)) + if not b: + # TODO: This does not do the right thing when b is False + i = NonElementaryIntegral(cancel(i.as_expr()).subs(s), DE.x) + else: + i = cancel(i.as_expr()) + + return (ret, i, b) + + +def integrate_hyperexponential_polynomial(p, DE, z): + """ + Integration of hyperexponential polynomials. + + Explanation + =========== + + Given a hyperexponential monomial t over k and ``p`` in k[t, 1/t], return q in + k[t, 1/t] and a bool b in {True, False} such that p - Dq in k if b is True, + or p - Dq does not have an elementary integral over k(t) if b is False. + """ + t1 = DE.t + dtt = DE.d.exquo(Poly(DE.t, DE.t)) + qa = Poly(0, DE.t) + qd = Poly(1, DE.t) + b = True + + if p.is_zero: + return(qa, qd, b) + + from sympy.integrals.rde import rischDE + + with DecrementLevel(DE): + for i in range(-p.degree(z), p.degree(t1) + 1): + if not i: + continue + elif i < 0: + # If you get AttributeError: 'NoneType' object has no attribute 'nth' + # then this should really not have expand=False + # But it shouldn't happen because p is already a Poly in t and z + a = p.as_poly(z, expand=False).nth(-i) + else: + # If you get AttributeError: 'NoneType' object has no attribute 'nth' + # then this should really not have expand=False + a = p.as_poly(t1, expand=False).nth(i) + + aa, ad = frac_in(a, DE.t, field=True) + aa, ad = aa.cancel(ad, include=True) + iDt = Poly(i, t1)*dtt + iDta, iDtd = frac_in(iDt, DE.t, field=True) + try: + va, vd = rischDE(iDta, iDtd, Poly(aa, DE.t), Poly(ad, DE.t), DE) + va, vd = frac_in((va, vd), t1, cancel=True) + except NonElementaryIntegralException: + b = False + else: + qa = qa*vd + va*Poly(t1**i)*qd + qd *= vd + + return (qa, qd, b) + + +def integrate_hyperexponential(a, d, DE, z=None, conds='piecewise'): + """ + Integration of hyperexponential functions. + + Explanation + =========== + + Given a hyperexponential monomial t over k and f in k(t), return g + elementary over k(t), i in k(t), and a bool b in {True, False} such that + i = f - Dg is in k if b is True or i = f - Dg does not have an elementary + integral over k(t) if b is False. + + This function returns a Basic expression for the first argument. If b is + True, the second argument is Basic expression in k to recursively integrate. + If b is False, the second argument is an unevaluated Integral, which has + been proven to be nonelementary. + """ + # XXX: a and d must be canceled, or this might return incorrect results + z = z or Dummy("z") + s = list(zip(reversed(DE.T), reversed([f(DE.x) for f in DE.Tfuncs]))) + + g1, h, r = hermite_reduce(a, d, DE) + g2, b = residue_reduce(h[0], h[1], DE, z=z) + if not b: + i = cancel(a.as_expr()/d.as_expr() - (g1[1]*derivation(g1[0], DE) - + g1[0]*derivation(g1[1], DE)).as_expr()/(g1[1]**2).as_expr() - + residue_reduce_derivation(g2, DE, z)) + i = NonElementaryIntegral(cancel(i.subs(s)), DE.x) + return ((g1[0].as_expr()/g1[1].as_expr()).subs(s) + + residue_reduce_to_basic(g2, DE, z), i, b) + + # p should be a polynomial in t and 1/t, because Sirr == k[t, 1/t] + # h - Dg2 + r + p = cancel(h[0].as_expr()/h[1].as_expr() - residue_reduce_derivation(g2, + DE, z) + r[0].as_expr()/r[1].as_expr()) + pp = as_poly_1t(p, DE.t, z) + + qa, qd, b = integrate_hyperexponential_polynomial(pp, DE, z) + + i = pp.nth(0, 0) + + ret = ((g1[0].as_expr()/g1[1].as_expr()).subs(s) \ + + residue_reduce_to_basic(g2, DE, z)) + + qas = qa.as_expr().subs(s) + qds = qd.as_expr().subs(s) + if conds == 'piecewise' and DE.x not in qds.free_symbols: + # We have to be careful if the exponent is S.Zero! + + # XXX: Does qd = 0 always necessarily correspond to the exponential + # equaling 1? + ret += Piecewise( + (qas/qds, Ne(qds, 0)), + (integrate((p - i).subs(DE.t, 1).subs(s), DE.x), True) + ) + else: + ret += qas/qds + + if not b: + i = p - (qd*derivation(qa, DE) - qa*derivation(qd, DE)).as_expr()/\ + (qd**2).as_expr() + i = NonElementaryIntegral(cancel(i).subs(s), DE.x) + return (ret, i, b) + + +def integrate_hypertangent_polynomial(p, DE): + """ + Integration of hypertangent polynomials. + + Explanation + =========== + + Given a differential field k such that sqrt(-1) is not in k, a + hypertangent monomial t over k, and p in k[t], return q in k[t] and + c in k such that p - Dq - c*D(t**2 + 1)/(t**1 + 1) is in k and p - + Dq does not have an elementary integral over k(t) if Dc != 0. + """ + # XXX: Make sure that sqrt(-1) is not in k. + q, r = polynomial_reduce(p, DE) + a = DE.d.exquo(Poly(DE.t**2 + 1, DE.t)) + c = Poly(r.nth(1)/(2*a.as_expr()), DE.t) + return (q, c) + + +def integrate_nonlinear_no_specials(a, d, DE, z=None): + """ + Integration of nonlinear monomials with no specials. + + Explanation + =========== + + Given a nonlinear monomial t over k such that Sirr ({p in k[t] | p is + special, monic, and irreducible}) is empty, and f in k(t), returns g + elementary over k(t) and a Boolean b in {True, False} such that f - Dg is + in k if b == True, or f - Dg does not have an elementary integral over k(t) + if b == False. + + This function is applicable to all nonlinear extensions, but in the case + where it returns b == False, it will only have proven that the integral of + f - Dg is nonelementary if Sirr is empty. + + This function returns a Basic expression. + """ + # TODO: Integral from k? + # TODO: split out nonelementary integral + # XXX: a and d must be canceled, or this might not return correct results + z = z or Dummy("z") + s = list(zip(reversed(DE.T), reversed([f(DE.x) for f in DE.Tfuncs]))) + + g1, h, r = hermite_reduce(a, d, DE) + g2, b = residue_reduce(h[0], h[1], DE, z=z) + if not b: + return ((g1[0].as_expr()/g1[1].as_expr()).subs(s) + + residue_reduce_to_basic(g2, DE, z), b) + + # Because f has no specials, this should be a polynomial in t, or else + # there is a bug. + p = cancel(h[0].as_expr()/h[1].as_expr() - residue_reduce_derivation(g2, + DE, z).as_expr() + r[0].as_expr()/r[1].as_expr()).as_poly(DE.t) + q1, q2 = polynomial_reduce(p, DE) + + if q2.expr.has(DE.t): + b = False + else: + b = True + + ret = (cancel(g1[0].as_expr()/g1[1].as_expr() + q1.as_expr()).subs(s) + + residue_reduce_to_basic(g2, DE, z)) + return (ret, b) + + +class NonElementaryIntegral(Integral): + """ + Represents a nonelementary Integral. + + Explanation + =========== + + If the result of integrate() is an instance of this class, it is + guaranteed to be nonelementary. Note that integrate() by default will try + to find any closed-form solution, even in terms of special functions which + may themselves not be elementary. To make integrate() only give + elementary solutions, or, in the cases where it can prove the integral to + be nonelementary, instances of this class, use integrate(risch=True). + In this case, integrate() may raise NotImplementedError if it cannot make + such a determination. + + integrate() uses the deterministic Risch algorithm to integrate elementary + functions or prove that they have no elementary integral. In some cases, + this algorithm can split an integral into an elementary and nonelementary + part, so that the result of integrate will be the sum of an elementary + expression and a NonElementaryIntegral. + + Examples + ======== + + >>> from sympy import integrate, exp, log, Integral + >>> from sympy.abc import x + + >>> a = integrate(exp(-x**2), x, risch=True) + >>> print(a) + Integral(exp(-x**2), x) + >>> type(a) + + + >>> expr = (2*log(x)**2 - log(x) - x**2)/(log(x)**3 - x**2*log(x)) + >>> b = integrate(expr, x, risch=True) + >>> print(b) + -log(-x + log(x))/2 + log(x + log(x))/2 + Integral(1/log(x), x) + >>> type(b.atoms(Integral).pop()) + + + """ + # TODO: This is useful in and of itself, because isinstance(result, + # NonElementaryIntegral) will tell if the integral has been proven to be + # elementary. But should we do more? Perhaps a no-op .doit() if + # elementary=True? Or maybe some information on why the integral is + # nonelementary. + pass + + +def risch_integrate(f, x, extension=None, handle_first='log', + separate_integral=False, rewrite_complex=None, + conds='piecewise'): + r""" + The Risch Integration Algorithm. + + Explanation + =========== + + Only transcendental functions are supported. Currently, only exponentials + and logarithms are supported, but support for trigonometric functions is + forthcoming. + + If this function returns an unevaluated Integral in the result, it means + that it has proven that integral to be nonelementary. Any errors will + result in raising NotImplementedError. The unevaluated Integral will be + an instance of NonElementaryIntegral, a subclass of Integral. + + handle_first may be either 'exp' or 'log'. This changes the order in + which the extension is built, and may result in a different (but + equivalent) solution (for an example of this, see issue 5109). It is also + possible that the integral may be computed with one but not the other, + because not all cases have been implemented yet. It defaults to 'log' so + that the outer extension is exponential when possible, because more of the + exponential case has been implemented. + + If ``separate_integral`` is ``True``, the result is returned as a tuple (ans, i), + where the integral is ans + i, ans is elementary, and i is either a + NonElementaryIntegral or 0. This useful if you want to try further + integrating the NonElementaryIntegral part using other algorithms to + possibly get a solution in terms of special functions. It is False by + default. + + Examples + ======== + + >>> from sympy.integrals.risch import risch_integrate + >>> from sympy import exp, log, pprint + >>> from sympy.abc import x + + First, we try integrating exp(-x**2). Except for a constant factor of + 2/sqrt(pi), this is the famous error function. + + >>> pprint(risch_integrate(exp(-x**2), x)) + / + | + | 2 + | -x + | e dx + | + / + + The unevaluated Integral in the result means that risch_integrate() has + proven that exp(-x**2) does not have an elementary anti-derivative. + + In many cases, risch_integrate() can split out the elementary + anti-derivative part from the nonelementary anti-derivative part. + For example, + + >>> pprint(risch_integrate((2*log(x)**2 - log(x) - x**2)/(log(x)**3 - + ... x**2*log(x)), x)) + / + | + log(-x + log(x)) log(x + log(x)) | 1 + - ---------------- + --------------- + | ------ dx + 2 2 | log(x) + | + / + + This means that it has proven that the integral of 1/log(x) is + nonelementary. This function is also known as the logarithmic integral, + and is often denoted as Li(x). + + risch_integrate() currently only accepts purely transcendental functions + with exponentials and logarithms, though note that this can include + nested exponentials and logarithms, as well as exponentials with bases + other than E. + + >>> pprint(risch_integrate(exp(x)*exp(exp(x)), x)) + / x\ + \e / + e + >>> pprint(risch_integrate(exp(exp(x)), x)) + / + | + | / x\ + | \e / + | e dx + | + / + + >>> pprint(risch_integrate(x*x**x*log(x) + x**x + x*x**x, x)) + x + x*x + >>> pprint(risch_integrate(x**x, x)) + / + | + | x + | x dx + | + / + + >>> pprint(risch_integrate(-1/(x*log(x)*log(log(x))**2), x)) + 1 + ----------- + log(log(x)) + + """ + f = S(f) + + DE = extension or DifferentialExtension(f, x, handle_first=handle_first, + dummy=True, rewrite_complex=rewrite_complex) + fa, fd = DE.fa, DE.fd + + result = S.Zero + for case in reversed(DE.cases): + if not fa.expr.has(DE.t) and not fd.expr.has(DE.t) and not case == 'base': + DE.decrement_level() + fa, fd = frac_in((fa, fd), DE.t) + continue + + fa, fd = fa.cancel(fd, include=True) + if case == 'exp': + ans, i, b = integrate_hyperexponential(fa, fd, DE, conds=conds) + elif case == 'primitive': + ans, i, b = integrate_primitive(fa, fd, DE) + elif case == 'base': + # XXX: We can't call ratint() directly here because it doesn't + # handle polynomials correctly. + ans = integrate(fa.as_expr()/fd.as_expr(), DE.x, risch=False) + b = False + i = S.Zero + else: + raise NotImplementedError("Only exponential and logarithmic " + "extensions are currently supported.") + + result += ans + if b: + DE.decrement_level() + fa, fd = frac_in(i, DE.t) + else: + result = result.subs(DE.backsubs) + if not i.is_zero: + i = NonElementaryIntegral(i.function.subs(DE.backsubs),i.limits) + if not separate_integral: + result += i + return result + else: + + if isinstance(i, NonElementaryIntegral): + return (result, i) + else: + return (result, 0) diff --git a/.venv/lib/python3.13/site-packages/sympy/integrals/singularityfunctions.py b/.venv/lib/python3.13/site-packages/sympy/integrals/singularityfunctions.py new file mode 100644 index 0000000000000000000000000000000000000000..3e33d0542c45b67b193f17e00c25837f3a82109a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/integrals/singularityfunctions.py @@ -0,0 +1,63 @@ +from sympy.functions import SingularityFunction, DiracDelta +from sympy.integrals import integrate + + +def singularityintegrate(f, x): + """ + This function handles the indefinite integrations of Singularity functions. + The ``integrate`` function calls this function internally whenever an + instance of SingularityFunction is passed as argument. + + Explanation + =========== + + The idea for integration is the following: + + - If we are dealing with a SingularityFunction expression, + i.e. ``SingularityFunction(x, a, n)``, we just return + ``SingularityFunction(x, a, n + 1)/(n + 1)`` if ``n >= 0`` and + ``SingularityFunction(x, a, n + 1)`` if ``n < 0``. + + - If the node is a multiplication or power node having a + SingularityFunction term we rewrite the whole expression in terms of + Heaviside and DiracDelta and then integrate the output. Lastly, we + rewrite the output of integration back in terms of SingularityFunction. + + - If none of the above case arises, we return None. + + Examples + ======== + + >>> from sympy.integrals.singularityfunctions import singularityintegrate + >>> from sympy import SingularityFunction, symbols, Function + >>> x, a, n, y = symbols('x a n y') + >>> f = Function('f') + >>> singularityintegrate(SingularityFunction(x, a, 3), x) + SingularityFunction(x, a, 4)/4 + >>> singularityintegrate(5*SingularityFunction(x, 5, -2), x) + 5*SingularityFunction(x, 5, -1) + >>> singularityintegrate(6*SingularityFunction(x, 5, -1), x) + 6*SingularityFunction(x, 5, 0) + >>> singularityintegrate(x*SingularityFunction(x, 0, -1), x) + 0 + >>> singularityintegrate(SingularityFunction(x, 1, -1) * f(x), x) + f(1)*SingularityFunction(x, 1, 0) + + """ + + if not f.has(SingularityFunction): + return None + + if isinstance(f, SingularityFunction): + x, a, n = f.args + if n.is_positive or n.is_zero: + return SingularityFunction(x, a, n + 1)/(n + 1) + elif n in (-1, -2, -3, -4): + return SingularityFunction(x, a, n + 1) + + if f.is_Mul or f.is_Pow: + + expr = f.rewrite(DiracDelta) + expr = integrate(expr, x) + return expr.rewrite(SingularityFunction) + return None diff --git a/.venv/lib/python3.13/site-packages/sympy/integrals/transforms.py b/.venv/lib/python3.13/site-packages/sympy/integrals/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..19dbd990e4f0320e76707a1a2a8b1601fcd8a3df --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/integrals/transforms.py @@ -0,0 +1,1590 @@ +""" Integral Transforms """ +from functools import reduce, wraps +from itertools import repeat +from sympy.core import S, pi +from sympy.core.add import Add +from sympy.core.function import ( + AppliedUndef, count_ops, expand, expand_mul, Function) +from sympy.core.mul import Mul +from sympy.core.intfunc import igcd, ilcm +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Dummy +from sympy.core.traversal import postorder_traversal +from sympy.functions.combinatorial.factorials import factorial, rf +from sympy.functions.elementary.complexes import re, arg, Abs +from sympy.functions.elementary.exponential import exp, exp_polar +from sympy.functions.elementary.hyperbolic import cosh, coth, sinh, tanh +from sympy.functions.elementary.integers import ceiling +from sympy.functions.elementary.miscellaneous import Max, Min, sqrt +from sympy.functions.elementary.piecewise import piecewise_fold +from sympy.functions.elementary.trigonometric import cos, cot, sin, tan +from sympy.functions.special.bessel import besselj +from sympy.functions.special.delta_functions import Heaviside +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.hyper import meijerg +from sympy.integrals import integrate, Integral +from sympy.integrals.meijerint import _dummy +from sympy.logic.boolalg import to_cnf, conjuncts, disjuncts, Or, And +from sympy.polys.polyroots import roots +from sympy.polys.polytools import factor, Poly +from sympy.polys.rootoftools import CRootOf +from sympy.utilities.iterables import iterable +from sympy.utilities.misc import debug + + +########################################################################## +# Helpers / Utilities +########################################################################## + + +class IntegralTransformError(NotImplementedError): + """ + Exception raised in relation to problems computing transforms. + + Explanation + =========== + + This class is mostly used internally; if integrals cannot be computed + objects representing unevaluated transforms are usually returned. + + The hint ``needeval=True`` can be used to disable returning transform + objects, and instead raise this exception if an integral cannot be + computed. + """ + def __init__(self, transform, function, msg): + super().__init__( + "%s Transform could not be computed: %s." % (transform, msg)) + self.function = function + + +class IntegralTransform(Function): + """ + Base class for integral transforms. + + Explanation + =========== + + This class represents unevaluated transforms. + + To implement a concrete transform, derive from this class and implement + the ``_compute_transform(f, x, s, **hints)`` and ``_as_integral(f, x, s)`` + functions. If the transform cannot be computed, raise :obj:`IntegralTransformError`. + + Also set ``cls._name``. For instance, + + >>> from sympy import LaplaceTransform + >>> LaplaceTransform._name + 'Laplace' + + Implement ``self._collapse_extra`` if your function returns more than just a + number and possibly a convergence condition. + """ + + @property + def function(self): + """ The function to be transformed. """ + return self.args[0] + + @property + def function_variable(self): + """ The dependent variable of the function to be transformed. """ + return self.args[1] + + @property + def transform_variable(self): + """ The independent transform variable. """ + return self.args[2] + + @property + def free_symbols(self): + """ + This method returns the symbols that will exist when the transform + is evaluated. + """ + return self.function.free_symbols.union({self.transform_variable}) \ + - {self.function_variable} + + def _compute_transform(self, f, x, s, **hints): + raise NotImplementedError + + def _as_integral(self, f, x, s): + raise NotImplementedError + + def _collapse_extra(self, extra): + cond = And(*extra) + if cond == False: + raise IntegralTransformError(self.__class__.name, None, '') + return cond + + def _try_directly(self, **hints): + T = None + try_directly = not any(func.has(self.function_variable) + for func in self.function.atoms(AppliedUndef)) + if try_directly: + try: + T = self._compute_transform(self.function, + self.function_variable, self.transform_variable, **hints) + except IntegralTransformError: + debug('[IT _try ] Caught IntegralTransformError, returns None') + T = None + + fn = self.function + if not fn.is_Add: + fn = expand_mul(fn) + return fn, T + + def doit(self, **hints): + """ + Try to evaluate the transform in closed form. + + Explanation + =========== + + This general function handles linearity, but apart from that leaves + pretty much everything to _compute_transform. + + Standard hints are the following: + + - ``simplify``: whether or not to simplify the result + - ``noconds``: if True, do not return convergence conditions + - ``needeval``: if True, raise IntegralTransformError instead of + returning IntegralTransform objects + + The default values of these hints depend on the concrete transform, + usually the default is + ``(simplify, noconds, needeval) = (True, False, False)``. + """ + needeval = hints.pop('needeval', False) + simplify = hints.pop('simplify', True) + hints['simplify'] = simplify + + fn, T = self._try_directly(**hints) + + if T is not None: + return T + + if fn.is_Add: + hints['needeval'] = needeval + res = [self.__class__(*([x] + list(self.args[1:]))).doit(**hints) + for x in fn.args] + extra = [] + ress = [] + for x in res: + if not isinstance(x, tuple): + x = [x] + ress.append(x[0]) + if len(x) == 2: + # only a condition + extra.append(x[1]) + elif len(x) > 2: + # some region parameters and a condition (Mellin, Laplace) + extra += [x[1:]] + if simplify==True: + res = Add(*ress).simplify() + else: + res = Add(*ress) + if not extra: + return res + try: + extra = self._collapse_extra(extra) + if iterable(extra): + return (res,) + tuple(extra) + else: + return (res, extra) + except IntegralTransformError: + pass + + if needeval: + raise IntegralTransformError( + self.__class__._name, self.function, 'needeval') + + # TODO handle derivatives etc + + # pull out constant coefficients + coeff, rest = fn.as_coeff_mul(self.function_variable) + return coeff*self.__class__(*([Mul(*rest)] + list(self.args[1:]))) + + @property + def as_integral(self): + return self._as_integral(self.function, self.function_variable, + self.transform_variable) + + def _eval_rewrite_as_Integral(self, *args, **kwargs): + return self.as_integral + + +def _simplify(expr, doit): + if doit: + from sympy.simplify import simplify + from sympy.simplify.powsimp import powdenest + return simplify(powdenest(piecewise_fold(expr), polar=True)) + return expr + + +def _noconds_(default): + """ + This is a decorator generator for dropping convergence conditions. + + Explanation + =========== + + Suppose you define a function ``transform(*args)`` which returns a tuple of + the form ``(result, cond1, cond2, ...)``. + + Decorating it ``@_noconds_(default)`` will add a new keyword argument + ``noconds`` to it. If ``noconds=True``, the return value will be altered to + be only ``result``, whereas if ``noconds=False`` the return value will not + be altered. + + The default value of the ``noconds`` keyword will be ``default`` (i.e. the + argument of this function). + """ + def make_wrapper(func): + @wraps(func) + def wrapper(*args, noconds=default, **kwargs): + res = func(*args, **kwargs) + if noconds: + return res[0] + return res + return wrapper + return make_wrapper +_noconds = _noconds_(False) + + +########################################################################## +# Mellin Transform +########################################################################## + +def _default_integrator(f, x): + return integrate(f, (x, S.Zero, S.Infinity)) + + +@_noconds +def _mellin_transform(f, x, s_, integrator=_default_integrator, simplify=True): + """ Backend function to compute Mellin transforms. """ + # We use a fresh dummy, because assumptions on s might drop conditions on + # convergence of the integral. + s = _dummy('s', 'mellin-transform', f) + F = integrator(x**(s - 1) * f, x) + + if not F.has(Integral): + return _simplify(F.subs(s, s_), simplify), (S.NegativeInfinity, S.Infinity), S.true + + if not F.is_Piecewise: # XXX can this work if integration gives continuous result now? + raise IntegralTransformError('Mellin', f, 'could not compute integral') + + F, cond = F.args[0] + if F.has(Integral): + raise IntegralTransformError( + 'Mellin', f, 'integral in unexpected form') + + def process_conds(cond): + """ + Turn ``cond`` into a strip (a, b), and auxiliary conditions. + """ + from sympy.solvers.inequalities import _solve_inequality + a = S.NegativeInfinity + b = S.Infinity + aux = S.true + conds = conjuncts(to_cnf(cond)) + t = Dummy('t', real=True) + for c in conds: + a_ = S.Infinity + b_ = S.NegativeInfinity + aux_ = [] + for d in disjuncts(c): + d_ = d.replace( + re, lambda x: x.as_real_imag()[0]).subs(re(s), t) + if not d.is_Relational or \ + d.rel_op in ('==', '!=') \ + or d_.has(s) or not d_.has(t): + aux_ += [d] + continue + soln = _solve_inequality(d_, t) + if not soln.is_Relational or \ + soln.rel_op in ('==', '!='): + aux_ += [d] + continue + if soln.lts == t: + b_ = Max(soln.gts, b_) + else: + a_ = Min(soln.lts, a_) + if a_ is not S.Infinity and a_ != b: + a = Max(a_, a) + elif b_ is not S.NegativeInfinity and b_ != a: + b = Min(b_, b) + else: + aux = And(aux, Or(*aux_)) + return a, b, aux + + conds = [process_conds(c) for c in disjuncts(cond)] + conds = [x for x in conds if x[2] != False] + conds.sort(key=lambda x: (x[0] - x[1], count_ops(x[2]))) + + if not conds: + raise IntegralTransformError('Mellin', f, 'no convergence found') + + a, b, aux = conds[0] + return _simplify(F.subs(s, s_), simplify), (a, b), aux + + +class MellinTransform(IntegralTransform): + """ + Class representing unevaluated Mellin transforms. + + For usage of this class, see the :class:`IntegralTransform` docstring. + + For how to compute Mellin transforms, see the :func:`mellin_transform` + docstring. + """ + + _name = 'Mellin' + + def _compute_transform(self, f, x, s, **hints): + return _mellin_transform(f, x, s, **hints) + + def _as_integral(self, f, x, s): + return Integral(f*x**(s - 1), (x, S.Zero, S.Infinity)) + + def _collapse_extra(self, extra): + a = [] + b = [] + cond = [] + for (sa, sb), c in extra: + a += [sa] + b += [sb] + cond += [c] + res = (Max(*a), Min(*b)), And(*cond) + if (res[0][0] >= res[0][1]) == True or res[1] == False: + raise IntegralTransformError( + 'Mellin', None, 'no combined convergence.') + return res + + +def mellin_transform(f, x, s, **hints): + r""" + Compute the Mellin transform `F(s)` of `f(x)`, + + .. math :: F(s) = \int_0^\infty x^{s-1} f(x) \mathrm{d}x. + + For all "sensible" functions, this converges absolutely in a strip + `a < \operatorname{Re}(s) < b`. + + Explanation + =========== + + The Mellin transform is related via change of variables to the Fourier + transform, and also to the (bilateral) Laplace transform. + + This function returns ``(F, (a, b), cond)`` + where ``F`` is the Mellin transform of ``f``, ``(a, b)`` is the fundamental strip + (as above), and ``cond`` are auxiliary convergence conditions. + + If the integral cannot be computed in closed form, this function returns + an unevaluated :class:`MellinTransform` object. + + For a description of possible hints, refer to the docstring of + :func:`sympy.integrals.transforms.IntegralTransform.doit`. If ``noconds=False``, + then only `F` will be returned (i.e. not ``cond``, and also not the strip + ``(a, b)``). + + Examples + ======== + + >>> from sympy import mellin_transform, exp + >>> from sympy.abc import x, s + >>> mellin_transform(exp(-x), x, s) + (gamma(s), (0, oo), True) + + See Also + ======== + + inverse_mellin_transform, laplace_transform, fourier_transform + hankel_transform, inverse_hankel_transform + """ + return MellinTransform(f, x, s).doit(**hints) + + +def _rewrite_sin(m_n, s, a, b): + """ + Re-write the sine function ``sin(m*s + n)`` as gamma functions, compatible + with the strip (a, b). + + Return ``(gamma1, gamma2, fac)`` so that ``f == fac/(gamma1 * gamma2)``. + + Examples + ======== + + >>> from sympy.integrals.transforms import _rewrite_sin + >>> from sympy import pi, S + >>> from sympy.abc import s + >>> _rewrite_sin((pi, 0), s, 0, 1) + (gamma(s), gamma(1 - s), pi) + >>> _rewrite_sin((pi, 0), s, 1, 0) + (gamma(s - 1), gamma(2 - s), -pi) + >>> _rewrite_sin((pi, 0), s, -1, 0) + (gamma(s + 1), gamma(-s), -pi) + >>> _rewrite_sin((pi, pi/2), s, S(1)/2, S(3)/2) + (gamma(s - 1/2), gamma(3/2 - s), -pi) + >>> _rewrite_sin((pi, pi), s, 0, 1) + (gamma(s), gamma(1 - s), -pi) + >>> _rewrite_sin((2*pi, 0), s, 0, S(1)/2) + (gamma(2*s), gamma(1 - 2*s), pi) + >>> _rewrite_sin((2*pi, 0), s, S(1)/2, 1) + (gamma(2*s - 1), gamma(2 - 2*s), -pi) + """ + # (This is a separate function because it is moderately complicated, + # and I want to doctest it.) + # We want to use pi/sin(pi*x) = gamma(x)*gamma(1-x). + # But there is one complication: the gamma functions determine the + # integration contour in the definition of the G-function. Usually + # it would not matter if this is slightly shifted, unless this way + # we create an undefined function! + # So we try to write this in such a way that the gammas are + # eminently on the right side of the strip. + m, n = m_n + + m = expand_mul(m/pi) + n = expand_mul(n/pi) + r = ceiling(-m*a - n.as_real_imag()[0]) # Don't use re(n), does not expand + return gamma(m*s + n + r), gamma(1 - n - r - m*s), (-1)**r*pi + + +class MellinTransformStripError(ValueError): + """ + Exception raised by _rewrite_gamma. Mainly for internal use. + """ + pass + + +def _rewrite_gamma(f, s, a, b): + """ + Try to rewrite the product f(s) as a product of gamma functions, + so that the inverse Mellin transform of f can be expressed as a meijer + G function. + + Explanation + =========== + + Return (an, ap), (bm, bq), arg, exp, fac such that + G((an, ap), (bm, bq), arg/z**exp)*fac is the inverse Mellin transform of f(s). + + Raises IntegralTransformError or MellinTransformStripError on failure. + + It is asserted that f has no poles in the fundamental strip designated by + (a, b). One of a and b is allowed to be None. The fundamental strip is + important, because it determines the inversion contour. + + This function can handle exponentials, linear factors, trigonometric + functions. + + This is a helper function for inverse_mellin_transform that will not + attempt any transformations on f. + + Examples + ======== + + >>> from sympy.integrals.transforms import _rewrite_gamma + >>> from sympy.abc import s + >>> from sympy import oo + >>> _rewrite_gamma(s*(s+3)*(s-1), s, -oo, oo) + (([], [-3, 0, 1]), ([-2, 1, 2], []), 1, 1, -1) + >>> _rewrite_gamma((s-1)**2, s, -oo, oo) + (([], [1, 1]), ([2, 2], []), 1, 1, 1) + + Importance of the fundamental strip: + + >>> _rewrite_gamma(1/s, s, 0, oo) + (([1], []), ([], [0]), 1, 1, 1) + >>> _rewrite_gamma(1/s, s, None, oo) + (([1], []), ([], [0]), 1, 1, 1) + >>> _rewrite_gamma(1/s, s, 0, None) + (([1], []), ([], [0]), 1, 1, 1) + >>> _rewrite_gamma(1/s, s, -oo, 0) + (([], [1]), ([0], []), 1, 1, -1) + >>> _rewrite_gamma(1/s, s, None, 0) + (([], [1]), ([0], []), 1, 1, -1) + >>> _rewrite_gamma(1/s, s, -oo, None) + (([], [1]), ([0], []), 1, 1, -1) + + >>> _rewrite_gamma(2**(-s+3), s, -oo, oo) + (([], []), ([], []), 1/2, 1, 8) + """ + # Our strategy will be as follows: + # 1) Guess a constant c such that the inversion integral should be + # performed wrt s'=c*s (instead of plain s). Write s for s'. + # 2) Process all factors, rewrite them independently as gamma functions in + # argument s, or exponentials of s. + # 3) Try to transform all gamma functions s.t. they have argument + # a+s or a-s. + # 4) Check that the resulting G function parameters are valid. + # 5) Combine all the exponentials. + + a_, b_ = S([a, b]) + + def left(c, is_numer): + """ + Decide whether pole at c lies to the left of the fundamental strip. + """ + # heuristically, this is the best chance for us to solve the inequalities + c = expand(re(c)) + if a_ is None and b_ is S.Infinity: + return True + if a_ is None: + return c < b_ + if b_ is None: + return c <= a_ + if (c >= b_) == True: + return False + if (c <= a_) == True: + return True + if is_numer: + return None + if a_.free_symbols or b_.free_symbols or c.free_symbols: + return None # XXX + #raise IntegralTransformError('Inverse Mellin', f, + # 'Could not determine position of singularity %s' + # ' relative to fundamental strip' % c) + raise MellinTransformStripError('Pole inside critical strip?') + + # 1) + s_multipliers = [] + for g in f.atoms(gamma): + if not g.has(s): + continue + arg = g.args[0] + if arg.is_Add: + arg = arg.as_independent(s)[1] + coeff, _ = arg.as_coeff_mul(s) + s_multipliers += [coeff] + for g in f.atoms(sin, cos, tan, cot): + if not g.has(s): + continue + arg = g.args[0] + if arg.is_Add: + arg = arg.as_independent(s)[1] + coeff, _ = arg.as_coeff_mul(s) + s_multipliers += [coeff/pi] + s_multipliers = [Abs(x) if x.is_extended_real else x for x in s_multipliers] + common_coefficient = S.One + for x in s_multipliers: + if not x.is_Rational: + common_coefficient = x + break + s_multipliers = [x/common_coefficient for x in s_multipliers] + if not (all(x.is_Rational for x in s_multipliers) and + common_coefficient.is_extended_real): + raise IntegralTransformError("Gamma", None, "Nonrational multiplier") + s_multiplier = common_coefficient/reduce(ilcm, [S(x.q) + for x in s_multipliers], S.One) + if s_multiplier == common_coefficient: + if len(s_multipliers) == 0: + s_multiplier = common_coefficient + else: + s_multiplier = common_coefficient \ + *reduce(igcd, [S(x.p) for x in s_multipliers]) + + f = f.subs(s, s/s_multiplier) + fac = S.One/s_multiplier + exponent = S.One/s_multiplier + if a_ is not None: + a_ *= s_multiplier + if b_ is not None: + b_ *= s_multiplier + + # 2) + numer, denom = f.as_numer_denom() + numer = Mul.make_args(numer) + denom = Mul.make_args(denom) + args = list(zip(numer, repeat(True))) + list(zip(denom, repeat(False))) + + facs = [] + dfacs = [] + # *_gammas will contain pairs (a, c) representing Gamma(a*s + c) + numer_gammas = [] + denom_gammas = [] + # exponentials will contain bases for exponentials of s + exponentials = [] + + def exception(fact): + return IntegralTransformError("Inverse Mellin", f, "Unrecognised form '%s'." % fact) + while args: + fact, is_numer = args.pop() + if is_numer: + ugammas, lgammas = numer_gammas, denom_gammas + ufacs = facs + else: + ugammas, lgammas = denom_gammas, numer_gammas + ufacs = dfacs + + def linear_arg(arg): + """ Test if arg is of form a*s+b, raise exception if not. """ + if not arg.is_polynomial(s): + raise exception(fact) + p = Poly(arg, s) + if p.degree() != 1: + raise exception(fact) + return p.all_coeffs() + + # constants + if not fact.has(s): + ufacs += [fact] + # exponentials + elif fact.is_Pow or isinstance(fact, exp): + if fact.is_Pow: + base = fact.base + exp_ = fact.exp + else: + base = exp_polar(1) + exp_ = fact.exp + if exp_.is_Integer: + cond = is_numer + if exp_ < 0: + cond = not cond + args += [(base, cond)]*Abs(exp_) + continue + elif not base.has(s): + a, b = linear_arg(exp_) + if not is_numer: + base = 1/base + exponentials += [base**a] + facs += [base**b] + else: + raise exception(fact) + # linear factors + elif fact.is_polynomial(s): + p = Poly(fact, s) + if p.degree() != 1: + # We completely factor the poly. For this we need the roots. + # Now roots() only works in some cases (low degree), and CRootOf + # only works without parameters. So try both... + coeff = p.LT()[1] + rs = roots(p, s) + if len(rs) != p.degree(): + rs = CRootOf.all_roots(p) + ufacs += [coeff] + args += [(s - c, is_numer) for c in rs] + continue + a, c = p.all_coeffs() + ufacs += [a] + c /= -a + # Now need to convert s - c + if left(c, is_numer): + ugammas += [(S.One, -c + 1)] + lgammas += [(S.One, -c)] + else: + ufacs += [-1] + ugammas += [(S.NegativeOne, c + 1)] + lgammas += [(S.NegativeOne, c)] + elif isinstance(fact, gamma): + a, b = linear_arg(fact.args[0]) + if is_numer: + if (a > 0 and (left(-b/a, is_numer) == False)) or \ + (a < 0 and (left(-b/a, is_numer) == True)): + raise NotImplementedError( + 'Gammas partially over the strip.') + ugammas += [(a, b)] + elif isinstance(fact, sin): + # We try to re-write all trigs as gammas. This is not in + # general the best strategy, since sometimes this is impossible, + # but rewriting as exponentials would work. However trig functions + # in inverse mellin transforms usually all come from simplifying + # gamma terms, so this should work. + a = fact.args[0] + if is_numer: + # No problem with the poles. + gamma1, gamma2, fac_ = gamma(a/pi), gamma(1 - a/pi), pi + else: + gamma1, gamma2, fac_ = _rewrite_sin(linear_arg(a), s, a_, b_) + args += [(gamma1, not is_numer), (gamma2, not is_numer)] + ufacs += [fac_] + elif isinstance(fact, tan): + a = fact.args[0] + args += [(sin(a, evaluate=False), is_numer), + (sin(pi/2 - a, evaluate=False), not is_numer)] + elif isinstance(fact, cos): + a = fact.args[0] + args += [(sin(pi/2 - a, evaluate=False), is_numer)] + elif isinstance(fact, cot): + a = fact.args[0] + args += [(sin(pi/2 - a, evaluate=False), is_numer), + (sin(a, evaluate=False), not is_numer)] + else: + raise exception(fact) + + fac *= Mul(*facs)/Mul(*dfacs) + + # 3) + an, ap, bm, bq = [], [], [], [] + for gammas, plus, minus, is_numer in [(numer_gammas, an, bm, True), + (denom_gammas, bq, ap, False)]: + while gammas: + a, c = gammas.pop() + if a != -1 and a != +1: + # We use the gamma function multiplication theorem. + p = Abs(S(a)) + newa = a/p + newc = c/p + if not a.is_Integer: + raise TypeError("a is not an integer") + for k in range(p): + gammas += [(newa, newc + k/p)] + if is_numer: + fac *= (2*pi)**((1 - p)/2) * p**(c - S.Half) + exponentials += [p**a] + else: + fac /= (2*pi)**((1 - p)/2) * p**(c - S.Half) + exponentials += [p**(-a)] + continue + if a == +1: + plus.append(1 - c) + else: + minus.append(c) + + # 4) + # TODO + + # 5) + arg = Mul(*exponentials) + + # for testability, sort the arguments + an.sort(key=default_sort_key) + ap.sort(key=default_sort_key) + bm.sort(key=default_sort_key) + bq.sort(key=default_sort_key) + + return (an, ap), (bm, bq), arg, exponent, fac + + +@_noconds_(True) +def _inverse_mellin_transform(F, s, x_, strip, as_meijerg=False): + """ A helper for the real inverse_mellin_transform function, this one here + assumes x to be real and positive. """ + x = _dummy('t', 'inverse-mellin-transform', F, positive=True) + # Actually, we won't try integration at all. Instead we use the definition + # of the Meijer G function as a fairly general inverse mellin transform. + F = F.rewrite(gamma) + for g in [factor(F), expand_mul(F), expand(F)]: + if g.is_Add: + # do all terms separately + ress = [_inverse_mellin_transform(G, s, x, strip, as_meijerg, + noconds=False) + for G in g.args] + conds = [p[1] for p in ress] + ress = [p[0] for p in ress] + res = Add(*ress) + if not as_meijerg: + res = factor(res, gens=res.atoms(Heaviside)) + return res.subs(x, x_), And(*conds) + + try: + a, b, C, e, fac = _rewrite_gamma(g, s, strip[0], strip[1]) + except IntegralTransformError: + continue + try: + G = meijerg(a, b, C/x**e) + except ValueError: + continue + if as_meijerg: + h = G + else: + try: + from sympy.simplify import hyperexpand + h = hyperexpand(G) + except NotImplementedError: + raise IntegralTransformError( + 'Inverse Mellin', F, 'Could not calculate integral') + + if h.is_Piecewise and len(h.args) == 3: + # XXX we break modularity here! + h = Heaviside(x - Abs(C))*h.args[0].args[0] \ + + Heaviside(Abs(C) - x)*h.args[1].args[0] + # We must ensure that the integral along the line we want converges, + # and return that value. + # See [L], 5.2 + cond = [Abs(arg(G.argument)) < G.delta*pi] + # Note: we allow ">=" here, this corresponds to convergence if we let + # limits go to oo symmetrically. ">" corresponds to absolute convergence. + cond += [And(Or(len(G.ap) != len(G.bq), 0 >= re(G.nu) + 1), + Abs(arg(G.argument)) == G.delta*pi)] + cond = Or(*cond) + if cond == False: + raise IntegralTransformError( + 'Inverse Mellin', F, 'does not converge') + return (h*fac).subs(x, x_), cond + + raise IntegralTransformError('Inverse Mellin', F, '') + +_allowed = None + + +class InverseMellinTransform(IntegralTransform): + """ + Class representing unevaluated inverse Mellin transforms. + + For usage of this class, see the :class:`IntegralTransform` docstring. + + For how to compute inverse Mellin transforms, see the + :func:`inverse_mellin_transform` docstring. + """ + + _name = 'Inverse Mellin' + _none_sentinel = Dummy('None') + _c = Dummy('c') + + def __new__(cls, F, s, x, a, b, **opts): + if a is None: + a = InverseMellinTransform._none_sentinel + if b is None: + b = InverseMellinTransform._none_sentinel + return IntegralTransform.__new__(cls, F, s, x, a, b, **opts) + + @property + def fundamental_strip(self): + a, b = self.args[3], self.args[4] + if a is InverseMellinTransform._none_sentinel: + a = None + if b is InverseMellinTransform._none_sentinel: + b = None + return a, b + + def _compute_transform(self, F, s, x, **hints): + # IntegralTransform's doit will cause this hint to exist, but + # InverseMellinTransform should ignore it + hints.pop('simplify', True) + global _allowed + if _allowed is None: + _allowed = { + exp, gamma, sin, cos, tan, cot, cosh, sinh, tanh, coth, + factorial, rf} + for f in postorder_traversal(F): + if f.is_Function and f.has(s) and f.func not in _allowed: + raise IntegralTransformError('Inverse Mellin', F, + 'Component %s not recognised.' % f) + strip = self.fundamental_strip + return _inverse_mellin_transform(F, s, x, strip, **hints) + + def _as_integral(self, F, s, x): + c = self.__class__._c + return Integral(F*x**(-s), (s, c - S.ImaginaryUnit*S.Infinity, c + + S.ImaginaryUnit*S.Infinity))/(2*S.Pi*S.ImaginaryUnit) + + +def inverse_mellin_transform(F, s, x, strip, **hints): + r""" + Compute the inverse Mellin transform of `F(s)` over the fundamental + strip given by ``strip=(a, b)``. + + Explanation + =========== + + This can be defined as + + .. math:: f(x) = \frac{1}{2\pi i} \int_{c - i\infty}^{c + i\infty} x^{-s} F(s) \mathrm{d}s, + + for any `c` in the fundamental strip. Under certain regularity + conditions on `F` and/or `f`, + this recovers `f` from its Mellin transform `F` + (and vice versa), for positive real `x`. + + One of `a` or `b` may be passed as ``None``; a suitable `c` will be + inferred. + + If the integral cannot be computed in closed form, this function returns + an unevaluated :class:`InverseMellinTransform` object. + + Note that this function will assume x to be positive and real, regardless + of the SymPy assumptions! + + For a description of possible hints, refer to the docstring of + :func:`sympy.integrals.transforms.IntegralTransform.doit`. + + Examples + ======== + + >>> from sympy import inverse_mellin_transform, oo, gamma + >>> from sympy.abc import x, s + >>> inverse_mellin_transform(gamma(s), s, x, (0, oo)) + exp(-x) + + The fundamental strip matters: + + >>> f = 1/(s**2 - 1) + >>> inverse_mellin_transform(f, s, x, (-oo, -1)) + x*(1 - 1/x**2)*Heaviside(x - 1)/2 + >>> inverse_mellin_transform(f, s, x, (-1, 1)) + -x*Heaviside(1 - x)/2 - Heaviside(x - 1)/(2*x) + >>> inverse_mellin_transform(f, s, x, (1, oo)) + (1/2 - x**2/2)*Heaviside(1 - x)/x + + See Also + ======== + + mellin_transform + hankel_transform, inverse_hankel_transform + """ + return InverseMellinTransform(F, s, x, strip[0], strip[1]).doit(**hints) + + +########################################################################## +# Fourier Transform +########################################################################## + +@_noconds_(True) +def _fourier_transform(f, x, k, a, b, name, simplify=True): + r""" + Compute a general Fourier-type transform + + .. math:: + + F(k) = a \int_{-\infty}^{\infty} e^{bixk} f(x)\, dx. + + For suitable choice of *a* and *b*, this reduces to the standard Fourier + and inverse Fourier transforms. + """ + F = integrate(a*f*exp(b*S.ImaginaryUnit*x*k), (x, S.NegativeInfinity, S.Infinity)) + + if not F.has(Integral): + return _simplify(F, simplify), S.true + + integral_f = integrate(f, (x, S.NegativeInfinity, S.Infinity)) + if integral_f in (S.NegativeInfinity, S.Infinity, S.NaN) or integral_f.has(Integral): + raise IntegralTransformError(name, f, 'function not integrable on real axis') + + if not F.is_Piecewise: + raise IntegralTransformError(name, f, 'could not compute integral') + + F, cond = F.args[0] + if F.has(Integral): + raise IntegralTransformError(name, f, 'integral in unexpected form') + + return _simplify(F, simplify), cond + + +class FourierTypeTransform(IntegralTransform): + """ Base class for Fourier transforms.""" + + def a(self): + raise NotImplementedError( + "Class %s must implement a(self) but does not" % self.__class__) + + def b(self): + raise NotImplementedError( + "Class %s must implement b(self) but does not" % self.__class__) + + def _compute_transform(self, f, x, k, **hints): + return _fourier_transform(f, x, k, + self.a(), self.b(), + self.__class__._name, **hints) + + def _as_integral(self, f, x, k): + a = self.a() + b = self.b() + return Integral(a*f*exp(b*S.ImaginaryUnit*x*k), (x, S.NegativeInfinity, S.Infinity)) + + +class FourierTransform(FourierTypeTransform): + """ + Class representing unevaluated Fourier transforms. + + For usage of this class, see the :class:`IntegralTransform` docstring. + + For how to compute Fourier transforms, see the :func:`fourier_transform` + docstring. + """ + + _name = 'Fourier' + + def a(self): + return 1 + + def b(self): + return -2*S.Pi + + +def fourier_transform(f, x, k, **hints): + r""" + Compute the unitary, ordinary-frequency Fourier transform of ``f``, defined + as + + .. math:: F(k) = \int_{-\infty}^\infty f(x) e^{-2\pi i x k} \mathrm{d} x. + + Explanation + =========== + + If the transform cannot be computed in closed form, this + function returns an unevaluated :class:`FourierTransform` object. + + For other Fourier transform conventions, see the function + :func:`sympy.integrals.transforms._fourier_transform`. + + For a description of possible hints, refer to the docstring of + :func:`sympy.integrals.transforms.IntegralTransform.doit`. + Note that for this transform, by default ``noconds=True``. + + Examples + ======== + + >>> from sympy import fourier_transform, exp + >>> from sympy.abc import x, k + >>> fourier_transform(exp(-x**2), x, k) + sqrt(pi)*exp(-pi**2*k**2) + >>> fourier_transform(exp(-x**2), x, k, noconds=False) + (sqrt(pi)*exp(-pi**2*k**2), True) + + See Also + ======== + + inverse_fourier_transform + sine_transform, inverse_sine_transform + cosine_transform, inverse_cosine_transform + hankel_transform, inverse_hankel_transform + mellin_transform, laplace_transform + """ + return FourierTransform(f, x, k).doit(**hints) + + +class InverseFourierTransform(FourierTypeTransform): + """ + Class representing unevaluated inverse Fourier transforms. + + For usage of this class, see the :class:`IntegralTransform` docstring. + + For how to compute inverse Fourier transforms, see the + :func:`inverse_fourier_transform` docstring. + """ + + _name = 'Inverse Fourier' + + def a(self): + return 1 + + def b(self): + return 2*S.Pi + + +def inverse_fourier_transform(F, k, x, **hints): + r""" + Compute the unitary, ordinary-frequency inverse Fourier transform of `F`, + defined as + + .. math:: f(x) = \int_{-\infty}^\infty F(k) e^{2\pi i x k} \mathrm{d} k. + + Explanation + =========== + + If the transform cannot be computed in closed form, this + function returns an unevaluated :class:`InverseFourierTransform` object. + + For other Fourier transform conventions, see the function + :func:`sympy.integrals.transforms._fourier_transform`. + + For a description of possible hints, refer to the docstring of + :func:`sympy.integrals.transforms.IntegralTransform.doit`. + Note that for this transform, by default ``noconds=True``. + + Examples + ======== + + >>> from sympy import inverse_fourier_transform, exp, sqrt, pi + >>> from sympy.abc import x, k + >>> inverse_fourier_transform(sqrt(pi)*exp(-(pi*k)**2), k, x) + exp(-x**2) + >>> inverse_fourier_transform(sqrt(pi)*exp(-(pi*k)**2), k, x, noconds=False) + (exp(-x**2), True) + + See Also + ======== + + fourier_transform + sine_transform, inverse_sine_transform + cosine_transform, inverse_cosine_transform + hankel_transform, inverse_hankel_transform + mellin_transform, laplace_transform + """ + return InverseFourierTransform(F, k, x).doit(**hints) + + +########################################################################## +# Fourier Sine and Cosine Transform +########################################################################## + +@_noconds_(True) +def _sine_cosine_transform(f, x, k, a, b, K, name, simplify=True): + """ + Compute a general sine or cosine-type transform + F(k) = a int_0^oo b*sin(x*k) f(x) dx. + F(k) = a int_0^oo b*cos(x*k) f(x) dx. + + For suitable choice of a and b, this reduces to the standard sine/cosine + and inverse sine/cosine transforms. + """ + F = integrate(a*f*K(b*x*k), (x, S.Zero, S.Infinity)) + + if not F.has(Integral): + return _simplify(F, simplify), S.true + + if not F.is_Piecewise: + raise IntegralTransformError(name, f, 'could not compute integral') + + F, cond = F.args[0] + if F.has(Integral): + raise IntegralTransformError(name, f, 'integral in unexpected form') + + return _simplify(F, simplify), cond + + +class SineCosineTypeTransform(IntegralTransform): + """ + Base class for sine and cosine transforms. + Specify cls._kern. + """ + + def a(self): + raise NotImplementedError( + "Class %s must implement a(self) but does not" % self.__class__) + + def b(self): + raise NotImplementedError( + "Class %s must implement b(self) but does not" % self.__class__) + + + def _compute_transform(self, f, x, k, **hints): + return _sine_cosine_transform(f, x, k, + self.a(), self.b(), + self.__class__._kern, + self.__class__._name, **hints) + + def _as_integral(self, f, x, k): + a = self.a() + b = self.b() + K = self.__class__._kern + return Integral(a*f*K(b*x*k), (x, S.Zero, S.Infinity)) + + +class SineTransform(SineCosineTypeTransform): + """ + Class representing unevaluated sine transforms. + + For usage of this class, see the :class:`IntegralTransform` docstring. + + For how to compute sine transforms, see the :func:`sine_transform` + docstring. + """ + + _name = 'Sine' + _kern = sin + + def a(self): + return sqrt(2)/sqrt(pi) + + def b(self): + return S.One + + +def sine_transform(f, x, k, **hints): + r""" + Compute the unitary, ordinary-frequency sine transform of `f`, defined + as + + .. math:: F(k) = \sqrt{\frac{2}{\pi}} \int_{0}^\infty f(x) \sin(2\pi x k) \mathrm{d} x. + + Explanation + =========== + + If the transform cannot be computed in closed form, this + function returns an unevaluated :class:`SineTransform` object. + + For a description of possible hints, refer to the docstring of + :func:`sympy.integrals.transforms.IntegralTransform.doit`. + Note that for this transform, by default ``noconds=True``. + + Examples + ======== + + >>> from sympy import sine_transform, exp + >>> from sympy.abc import x, k, a + >>> sine_transform(x*exp(-a*x**2), x, k) + sqrt(2)*k*exp(-k**2/(4*a))/(4*a**(3/2)) + >>> sine_transform(x**(-a), x, k) + 2**(1/2 - a)*k**(a - 1)*gamma(1 - a/2)/gamma(a/2 + 1/2) + + See Also + ======== + + fourier_transform, inverse_fourier_transform + inverse_sine_transform + cosine_transform, inverse_cosine_transform + hankel_transform, inverse_hankel_transform + mellin_transform, laplace_transform + """ + return SineTransform(f, x, k).doit(**hints) + + +class InverseSineTransform(SineCosineTypeTransform): + """ + Class representing unevaluated inverse sine transforms. + + For usage of this class, see the :class:`IntegralTransform` docstring. + + For how to compute inverse sine transforms, see the + :func:`inverse_sine_transform` docstring. + """ + + _name = 'Inverse Sine' + _kern = sin + + def a(self): + return sqrt(2)/sqrt(pi) + + def b(self): + return S.One + + +def inverse_sine_transform(F, k, x, **hints): + r""" + Compute the unitary, ordinary-frequency inverse sine transform of `F`, + defined as + + .. math:: f(x) = \sqrt{\frac{2}{\pi}} \int_{0}^\infty F(k) \sin(2\pi x k) \mathrm{d} k. + + Explanation + =========== + + If the transform cannot be computed in closed form, this + function returns an unevaluated :class:`InverseSineTransform` object. + + For a description of possible hints, refer to the docstring of + :func:`sympy.integrals.transforms.IntegralTransform.doit`. + Note that for this transform, by default ``noconds=True``. + + Examples + ======== + + >>> from sympy import inverse_sine_transform, exp, sqrt, gamma + >>> from sympy.abc import x, k, a + >>> inverse_sine_transform(2**((1-2*a)/2)*k**(a - 1)* + ... gamma(-a/2 + 1)/gamma((a+1)/2), k, x) + x**(-a) + >>> inverse_sine_transform(sqrt(2)*k*exp(-k**2/(4*a))/(4*sqrt(a)**3), k, x) + x*exp(-a*x**2) + + See Also + ======== + + fourier_transform, inverse_fourier_transform + sine_transform + cosine_transform, inverse_cosine_transform + hankel_transform, inverse_hankel_transform + mellin_transform, laplace_transform + """ + return InverseSineTransform(F, k, x).doit(**hints) + + +class CosineTransform(SineCosineTypeTransform): + """ + Class representing unevaluated cosine transforms. + + For usage of this class, see the :class:`IntegralTransform` docstring. + + For how to compute cosine transforms, see the :func:`cosine_transform` + docstring. + """ + + _name = 'Cosine' + _kern = cos + + def a(self): + return sqrt(2)/sqrt(pi) + + def b(self): + return S.One + + +def cosine_transform(f, x, k, **hints): + r""" + Compute the unitary, ordinary-frequency cosine transform of `f`, defined + as + + .. math:: F(k) = \sqrt{\frac{2}{\pi}} \int_{0}^\infty f(x) \cos(2\pi x k) \mathrm{d} x. + + Explanation + =========== + + If the transform cannot be computed in closed form, this + function returns an unevaluated :class:`CosineTransform` object. + + For a description of possible hints, refer to the docstring of + :func:`sympy.integrals.transforms.IntegralTransform.doit`. + Note that for this transform, by default ``noconds=True``. + + Examples + ======== + + >>> from sympy import cosine_transform, exp, sqrt, cos + >>> from sympy.abc import x, k, a + >>> cosine_transform(exp(-a*x), x, k) + sqrt(2)*a/(sqrt(pi)*(a**2 + k**2)) + >>> cosine_transform(exp(-a*sqrt(x))*cos(a*sqrt(x)), x, k) + a*exp(-a**2/(2*k))/(2*k**(3/2)) + + See Also + ======== + + fourier_transform, inverse_fourier_transform, + sine_transform, inverse_sine_transform + inverse_cosine_transform + hankel_transform, inverse_hankel_transform + mellin_transform, laplace_transform + """ + return CosineTransform(f, x, k).doit(**hints) + + +class InverseCosineTransform(SineCosineTypeTransform): + """ + Class representing unevaluated inverse cosine transforms. + + For usage of this class, see the :class:`IntegralTransform` docstring. + + For how to compute inverse cosine transforms, see the + :func:`inverse_cosine_transform` docstring. + """ + + _name = 'Inverse Cosine' + _kern = cos + + def a(self): + return sqrt(2)/sqrt(pi) + + def b(self): + return S.One + + +def inverse_cosine_transform(F, k, x, **hints): + r""" + Compute the unitary, ordinary-frequency inverse cosine transform of `F`, + defined as + + .. math:: f(x) = \sqrt{\frac{2}{\pi}} \int_{0}^\infty F(k) \cos(2\pi x k) \mathrm{d} k. + + Explanation + =========== + + If the transform cannot be computed in closed form, this + function returns an unevaluated :class:`InverseCosineTransform` object. + + For a description of possible hints, refer to the docstring of + :func:`sympy.integrals.transforms.IntegralTransform.doit`. + Note that for this transform, by default ``noconds=True``. + + Examples + ======== + + >>> from sympy import inverse_cosine_transform, sqrt, pi + >>> from sympy.abc import x, k, a + >>> inverse_cosine_transform(sqrt(2)*a/(sqrt(pi)*(a**2 + k**2)), k, x) + exp(-a*x) + >>> inverse_cosine_transform(1/sqrt(k), k, x) + 1/sqrt(x) + + See Also + ======== + + fourier_transform, inverse_fourier_transform, + sine_transform, inverse_sine_transform + cosine_transform + hankel_transform, inverse_hankel_transform + mellin_transform, laplace_transform + """ + return InverseCosineTransform(F, k, x).doit(**hints) + + +########################################################################## +# Hankel Transform +########################################################################## + +@_noconds_(True) +def _hankel_transform(f, r, k, nu, name, simplify=True): + r""" + Compute a general Hankel transform + + .. math:: F_\nu(k) = \int_{0}^\infty f(r) J_\nu(k r) r \mathrm{d} r. + """ + F = integrate(f*besselj(nu, k*r)*r, (r, S.Zero, S.Infinity)) + + if not F.has(Integral): + return _simplify(F, simplify), S.true + + if not F.is_Piecewise: + raise IntegralTransformError(name, f, 'could not compute integral') + + F, cond = F.args[0] + if F.has(Integral): + raise IntegralTransformError(name, f, 'integral in unexpected form') + + return _simplify(F, simplify), cond + + +class HankelTypeTransform(IntegralTransform): + """ + Base class for Hankel transforms. + """ + + def doit(self, **hints): + return self._compute_transform(self.function, + self.function_variable, + self.transform_variable, + self.args[3], + **hints) + + def _compute_transform(self, f, r, k, nu, **hints): + return _hankel_transform(f, r, k, nu, self._name, **hints) + + def _as_integral(self, f, r, k, nu): + return Integral(f*besselj(nu, k*r)*r, (r, S.Zero, S.Infinity)) + + @property + def as_integral(self): + return self._as_integral(self.function, + self.function_variable, + self.transform_variable, + self.args[3]) + + +class HankelTransform(HankelTypeTransform): + """ + Class representing unevaluated Hankel transforms. + + For usage of this class, see the :class:`IntegralTransform` docstring. + + For how to compute Hankel transforms, see the :func:`hankel_transform` + docstring. + """ + + _name = 'Hankel' + + +def hankel_transform(f, r, k, nu, **hints): + r""" + Compute the Hankel transform of `f`, defined as + + .. math:: F_\nu(k) = \int_{0}^\infty f(r) J_\nu(k r) r \mathrm{d} r. + + Explanation + =========== + + If the transform cannot be computed in closed form, this + function returns an unevaluated :class:`HankelTransform` object. + + For a description of possible hints, refer to the docstring of + :func:`sympy.integrals.transforms.IntegralTransform.doit`. + Note that for this transform, by default ``noconds=True``. + + Examples + ======== + + >>> from sympy import hankel_transform, inverse_hankel_transform + >>> from sympy import exp + >>> from sympy.abc import r, k, m, nu, a + + >>> ht = hankel_transform(1/r**m, r, k, nu) + >>> ht + 2*k**(m - 2)*gamma(-m/2 + nu/2 + 1)/(2**m*gamma(m/2 + nu/2)) + + >>> inverse_hankel_transform(ht, k, r, nu) + r**(-m) + + >>> ht = hankel_transform(exp(-a*r), r, k, 0) + >>> ht + a/(k**3*(a**2/k**2 + 1)**(3/2)) + + >>> inverse_hankel_transform(ht, k, r, 0) + exp(-a*r) + + See Also + ======== + + fourier_transform, inverse_fourier_transform + sine_transform, inverse_sine_transform + cosine_transform, inverse_cosine_transform + inverse_hankel_transform + mellin_transform, laplace_transform + """ + return HankelTransform(f, r, k, nu).doit(**hints) + + +class InverseHankelTransform(HankelTypeTransform): + """ + Class representing unevaluated inverse Hankel transforms. + + For usage of this class, see the :class:`IntegralTransform` docstring. + + For how to compute inverse Hankel transforms, see the + :func:`inverse_hankel_transform` docstring. + """ + + _name = 'Inverse Hankel' + + +def inverse_hankel_transform(F, k, r, nu, **hints): + r""" + Compute the inverse Hankel transform of `F` defined as + + .. math:: f(r) = \int_{0}^\infty F_\nu(k) J_\nu(k r) k \mathrm{d} k. + + Explanation + =========== + + If the transform cannot be computed in closed form, this + function returns an unevaluated :class:`InverseHankelTransform` object. + + For a description of possible hints, refer to the docstring of + :func:`sympy.integrals.transforms.IntegralTransform.doit`. + Note that for this transform, by default ``noconds=True``. + + Examples + ======== + + >>> from sympy import hankel_transform, inverse_hankel_transform + >>> from sympy import exp + >>> from sympy.abc import r, k, m, nu, a + + >>> ht = hankel_transform(1/r**m, r, k, nu) + >>> ht + 2*k**(m - 2)*gamma(-m/2 + nu/2 + 1)/(2**m*gamma(m/2 + nu/2)) + + >>> inverse_hankel_transform(ht, k, r, nu) + r**(-m) + + >>> ht = hankel_transform(exp(-a*r), r, k, 0) + >>> ht + a/(k**3*(a**2/k**2 + 1)**(3/2)) + + >>> inverse_hankel_transform(ht, k, r, 0) + exp(-a*r) + + See Also + ======== + + fourier_transform, inverse_fourier_transform + sine_transform, inverse_sine_transform + cosine_transform, inverse_cosine_transform + hankel_transform + mellin_transform, laplace_transform + """ + return InverseHankelTransform(F, k, r, nu).doit(**hints) + + +########################################################################## +# Laplace Transform +########################################################################## + +# Stub classes and functions that used to be here +import sympy.integrals.laplace as _laplace + +LaplaceTransform = _laplace.LaplaceTransform +laplace_transform = _laplace.laplace_transform +laplace_correspondence = _laplace.laplace_correspondence +laplace_initial_conds = _laplace.laplace_initial_conds +InverseLaplaceTransform = _laplace.InverseLaplaceTransform +inverse_laplace_transform = _laplace.inverse_laplace_transform diff --git a/.venv/lib/python3.13/site-packages/sympy/integrals/trigonometry.py b/.venv/lib/python3.13/site-packages/sympy/integrals/trigonometry.py new file mode 100644 index 0000000000000000000000000000000000000000..dd6389bcc79f28ed6c255546685da1a0e061c327 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/integrals/trigonometry.py @@ -0,0 +1,335 @@ +from sympy.core import cacheit, Dummy, Ne, Integer, Rational, S, Wild +from sympy.functions import binomial, sin, cos, Piecewise, Abs +from .integrals import integrate + +# TODO sin(a*x)*cos(b*x) -> sin((a+b)x) + sin((a-b)x) ? + +# creating, each time, Wild's and sin/cos/Mul is expensive. Also, our match & +# subs are very slow when not cached, and if we create Wild each time, we +# effectively block caching. +# +# so we cache the pattern + +# need to use a function instead of lamda since hash of lambda changes on +# each call to _pat_sincos +def _integer_instance(n): + return isinstance(n, Integer) + +@cacheit +def _pat_sincos(x): + a = Wild('a', exclude=[x]) + n, m = [Wild(s, exclude=[x], properties=[_integer_instance]) + for s in 'nm'] + pat = sin(a*x)**n * cos(a*x)**m + return pat, a, n, m + +_u = Dummy('u') + + +def trigintegrate(f, x, conds='piecewise'): + """ + Integrate f = Mul(trig) over x. + + Examples + ======== + + >>> from sympy import sin, cos, tan, sec + >>> from sympy.integrals.trigonometry import trigintegrate + >>> from sympy.abc import x + + >>> trigintegrate(sin(x)*cos(x), x) + sin(x)**2/2 + + >>> trigintegrate(sin(x)**2, x) + x/2 - sin(x)*cos(x)/2 + + >>> trigintegrate(tan(x)*sec(x), x) + 1/cos(x) + + >>> trigintegrate(sin(x)*tan(x), x) + -log(sin(x) - 1)/2 + log(sin(x) + 1)/2 - sin(x) + + References + ========== + + .. [1] https://en.wikibooks.org/wiki/Calculus/Integration_techniques + + See Also + ======== + + sympy.integrals.integrals.Integral.doit + sympy.integrals.integrals.Integral + """ + pat, a, n, m = _pat_sincos(x) + + f = f.rewrite('sincos') + M = f.match(pat) + + if M is None: + return + + n, m = M[n], M[m] + if n.is_zero and m.is_zero: + return x + zz = x if n.is_zero else S.Zero + + a = M[a] + + if n.is_odd or m.is_odd: + u = _u + n_, m_ = n.is_odd, m.is_odd + + # take smallest n or m -- to choose simplest substitution + if n_ and m_: + + # Make sure to choose the positive one + # otherwise an incorrect integral can occur. + if n < 0 and m > 0: + m_ = True + n_ = False + elif m < 0 and n > 0: + n_ = True + m_ = False + # Both are negative so choose the smallest n or m + # in absolute value for simplest substitution. + elif (n < 0 and m < 0): + n_ = n > m + m_ = not (n > m) + + # Both n and m are odd and positive + else: + n_ = (n < m) # NB: careful here, one of the + m_ = not (n < m) # conditions *must* be true + + # n m u=C (n-1)/2 m + # S(x) * C(x) dx --> -(1-u^2) * u du + if n_: + ff = -(1 - u**2)**((n - 1)/2) * u**m + uu = cos(a*x) + + # n m u=S n (m-1)/2 + # S(x) * C(x) dx --> u * (1-u^2) du + elif m_: + ff = u**n * (1 - u**2)**((m - 1)/2) + uu = sin(a*x) + + fi = integrate(ff, u) # XXX cyclic deps + fx = fi.subs(u, uu) + if conds == 'piecewise': + return Piecewise((fx / a, Ne(a, 0)), (zz, True)) + return fx / a + + # n & m are both even + # + # 2k 2m 2l 2l + # we transform S (x) * C (x) into terms with only S (x) or C (x) + # + # example: + # 100 4 100 2 2 100 4 2 + # S (x) * C (x) = S (x) * (1-S (x)) = S (x) * (1 + S (x) - 2*S (x)) + # + # 104 102 100 + # = S (x) - 2*S (x) + S (x) + # 2k + # then S is integrated with recursive formula + + # take largest n or m -- to choose simplest substitution + n_ = (Abs(n) > Abs(m)) + m_ = (Abs(m) > Abs(n)) + res = S.Zero + + if n_: + # 2k 2 k i 2i + # C = (1 - S ) = sum(i, (-) * B(k, i) * S ) + if m > 0: + for i in range(0, m//2 + 1): + res += (S.NegativeOne**i * binomial(m//2, i) * + _sin_pow_integrate(n + 2*i, x)) + + elif m == 0: + res = _sin_pow_integrate(n, x) + else: + + # m < 0 , |n| > |m| + # / + # | + # | m n + # | cos (x) sin (x) dx = + # | + # | + #/ + # / + # | + # -1 m+1 n-1 n - 1 | m+2 n-2 + # ________ cos (x) sin (x) + _______ | cos (x) sin (x) dx + # | + # m + 1 m + 1 | + # / + + res = (Rational(-1, m + 1) * cos(x)**(m + 1) * sin(x)**(n - 1) + + Rational(n - 1, m + 1) * + trigintegrate(cos(x)**(m + 2)*sin(x)**(n - 2), x)) + + elif m_: + # 2k 2 k i 2i + # S = (1 - C ) = sum(i, (-) * B(k, i) * C ) + if n > 0: + + # / / + # | | + # | m n | -m n + # | cos (x)*sin (x) dx or | cos (x) * sin (x) dx + # | | + # / / + # + # |m| > |n| ; m, n >0 ; m, n belong to Z - {0} + # n 2 + # sin (x) term is expanded here in terms of cos (x), + # and then integrated. + # + + for i in range(0, n//2 + 1): + res += (S.NegativeOne**i * binomial(n//2, i) * + _cos_pow_integrate(m + 2*i, x)) + + elif n == 0: + + # / + # | + # | 1 + # | _ _ _ + # | m + # | cos (x) + # / + # + + res = _cos_pow_integrate(m, x) + else: + + # n < 0 , |m| > |n| + # / + # | + # | m n + # | cos (x) sin (x) dx = + # | + # | + #/ + # / + # | + # 1 m-1 n+1 m - 1 | m-2 n+2 + # _______ cos (x) sin (x) + _______ | cos (x) sin (x) dx + # | + # n + 1 n + 1 | + # / + + res = (Rational(1, n + 1) * cos(x)**(m - 1)*sin(x)**(n + 1) + + Rational(m - 1, n + 1) * + trigintegrate(cos(x)**(m - 2)*sin(x)**(n + 2), x)) + + else: + if m == n: + ##Substitute sin(2x)/2 for sin(x)cos(x) and then Integrate. + res = integrate((sin(2*x)*S.Half)**m, x) + elif (m == -n): + if n < 0: + # Same as the scheme described above. + # the function argument to integrate in the end will + # be 1, this cannot be integrated by trigintegrate. + # Hence use sympy.integrals.integrate. + res = (Rational(1, n + 1) * cos(x)**(m - 1) * sin(x)**(n + 1) + + Rational(m - 1, n + 1) * + integrate(cos(x)**(m - 2) * sin(x)**(n + 2), x)) + else: + res = (Rational(-1, m + 1) * cos(x)**(m + 1) * sin(x)**(n - 1) + + Rational(n - 1, m + 1) * + integrate(cos(x)**(m + 2)*sin(x)**(n - 2), x)) + if conds == 'piecewise': + return Piecewise((res.subs(x, a*x) / a, Ne(a, 0)), (zz, True)) + return res.subs(x, a*x) / a + + +def _sin_pow_integrate(n, x): + if n > 0: + if n == 1: + #Recursion break + return -cos(x) + + # n > 0 + # / / + # | | + # | n -1 n-1 n - 1 | n-2 + # | sin (x) dx = ______ cos (x) sin (x) + _______ | sin (x) dx + # | | + # | n n | + #/ / + # + # + + return (Rational(-1, n) * cos(x) * sin(x)**(n - 1) + + Rational(n - 1, n) * _sin_pow_integrate(n - 2, x)) + + if n < 0: + if n == -1: + ##Make sure this does not come back here again. + ##Recursion breaks here or at n==0. + return trigintegrate(1/sin(x), x) + + # n < 0 + # / / + # | | + # | n 1 n+1 n + 2 | n+2 + # | sin (x) dx = _______ cos (x) sin (x) + _______ | sin (x) dx + # | | + # | n + 1 n + 1 | + #/ / + # + + return (Rational(1, n + 1) * cos(x) * sin(x)**(n + 1) + + Rational(n + 2, n + 1) * _sin_pow_integrate(n + 2, x)) + + else: + #n == 0 + #Recursion break. + return x + + +def _cos_pow_integrate(n, x): + if n > 0: + if n == 1: + #Recursion break. + return sin(x) + + # n > 0 + # / / + # | | + # | n 1 n-1 n - 1 | n-2 + # | sin (x) dx = ______ sin (x) cos (x) + _______ | cos (x) dx + # | | + # | n n | + #/ / + # + + return (Rational(1, n) * sin(x) * cos(x)**(n - 1) + + Rational(n - 1, n) * _cos_pow_integrate(n - 2, x)) + + if n < 0: + if n == -1: + ##Recursion break + return trigintegrate(1/cos(x), x) + + # n < 0 + # / / + # | | + # | n -1 n+1 n + 2 | n+2 + # | cos (x) dx = _______ sin (x) cos (x) + _______ | cos (x) dx + # | | + # | n + 1 n + 1 | + #/ / + # + + return (Rational(-1, n + 1) * sin(x) * cos(x)**(n + 1) + + Rational(n + 2, n + 1) * _cos_pow_integrate(n + 2, x)) + else: + # n == 0 + #Recursion Break. + return x diff --git a/.venv/lib/python3.13/site-packages/sympy/interactive/__init__.py b/.venv/lib/python3.13/site-packages/sympy/interactive/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1b3f043ada6222d79dd52fd28b035e2ea45c5683 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/interactive/__init__.py @@ -0,0 +1,8 @@ +"""Helper module for setting up interactive SymPy sessions. """ + +from .printing import init_printing +from .session import init_session +from .traversal import interactive_traversal + + +__all__ = ['init_printing', 'init_session', 'interactive_traversal'] diff --git a/.venv/lib/python3.13/site-packages/sympy/interactive/printing.py b/.venv/lib/python3.13/site-packages/sympy/interactive/printing.py new file mode 100644 index 0000000000000000000000000000000000000000..2fcc73e3e96a5b7e25f7fc7ebf54a5781c3b15b9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/interactive/printing.py @@ -0,0 +1,532 @@ +"""Tools for setting up printing in interactive sessions. """ + +from io import BytesIO + +from sympy.printing.latex import latex as default_latex +from sympy.printing.preview import preview +from sympy.utilities.misc import debug +from sympy.printing.defaults import Printable +from sympy.external import import_module + + +def _init_python_printing(stringify_func, **settings): + """Setup printing in Python interactive session. """ + import sys + import builtins + + def _displayhook(arg): + """Python's pretty-printer display hook. + + This function was adapted from: + + https://www.python.org/dev/peps/pep-0217/ + + """ + if arg is not None: + builtins._ = None + print(stringify_func(arg, **settings)) + builtins._ = arg + + sys.displayhook = _displayhook + + +def _init_ipython_printing(ip, stringify_func, use_latex, euler, forecolor, + backcolor, fontsize, latex_mode, print_builtin, + latex_printer, scale, **settings): + """Setup printing in IPython interactive session. """ + IPython = import_module("IPython", min_module_version="1.0") + try: + from IPython.lib.latextools import latex_to_png + except ImportError: + pass + + # Guess best font color if none was given based on the ip.colors string. + # From the IPython documentation: + # It has four case-insensitive values: 'nocolor', 'neutral', 'linux', + # 'lightbg'. The default is neutral, which should be legible on either + # dark or light terminal backgrounds. linux is optimised for dark + # backgrounds and lightbg for light ones. + if forecolor is None: + color = ip.colors.lower() + if color == 'lightbg': + forecolor = 'Black' + elif color == 'linux': + forecolor = 'White' + else: + # No idea, go with gray. + forecolor = 'Gray' + debug("init_printing: Automatic foreground color:", forecolor) + + if use_latex == "svg": + extra_preamble = "\n\\special{color %s}" % forecolor + else: + extra_preamble = "" + + imagesize = 'tight' + offset = "0cm,0cm" + resolution = round(150*scale) + dvi = r"-T %s -D %d -bg %s -fg %s -O %s" % ( + imagesize, resolution, backcolor, forecolor, offset) + dvioptions = dvi.split() + + svg_scale = 150/72*scale + dvioptions_svg = ["--no-fonts", "--scale={}".format(svg_scale)] + + debug("init_printing: DVIOPTIONS:", dvioptions) + debug("init_printing: DVIOPTIONS_SVG:", dvioptions_svg) + + latex = latex_printer or default_latex + + def _print_plain(arg, p, cycle): + """caller for pretty, for use in IPython 0.11""" + if _can_print(arg): + p.text(stringify_func(arg)) + else: + p.text(IPython.lib.pretty.pretty(arg)) + + def _preview_wrapper(o): + exprbuffer = BytesIO() + try: + preview(o, output='png', viewer='BytesIO', euler=euler, + outputbuffer=exprbuffer, extra_preamble=extra_preamble, + dvioptions=dvioptions, fontsize=fontsize) + except Exception as e: + # IPython swallows exceptions + debug("png printing:", "_preview_wrapper exception raised:", + repr(e)) + raise + return exprbuffer.getvalue() + + def _svg_wrapper(o): + exprbuffer = BytesIO() + try: + preview(o, output='svg', viewer='BytesIO', euler=euler, + outputbuffer=exprbuffer, extra_preamble=extra_preamble, + dvioptions=dvioptions_svg, fontsize=fontsize) + except Exception as e: + # IPython swallows exceptions + debug("svg printing:", "_preview_wrapper exception raised:", + repr(e)) + raise + return exprbuffer.getvalue().decode('utf-8') + + def _matplotlib_wrapper(o): + # mathtext can't render some LaTeX commands. For example, it can't + # render any LaTeX environments such as array or matrix. So here we + # ensure that if mathtext fails to render, we return None. + try: + try: + return latex_to_png(o, color=forecolor, scale=scale) + except TypeError: # Old IPython version without color and scale + return latex_to_png(o) + except ValueError as e: + debug('matplotlib exception caught:', repr(e)) + return None + + + # Hook methods for builtin SymPy printers + printing_hooks = ('_latex', '_sympystr', '_pretty', '_sympyrepr') + + + def _can_print(o): + """Return True if type o can be printed with one of the SymPy printers. + + If o is a container type, this is True if and only if every element of + o can be printed in this way. + """ + + try: + # If you're adding another type, make sure you add it to printable_types + # later in this file as well + + builtin_types = (list, tuple, set, frozenset) + if isinstance(o, builtin_types): + # If the object is a custom subclass with a custom str or + # repr, use that instead. + if (type(o).__str__ not in (i.__str__ for i in builtin_types) or + type(o).__repr__ not in (i.__repr__ for i in builtin_types)): + return False + return all(_can_print(i) for i in o) + elif isinstance(o, dict): + return all(_can_print(i) and _can_print(o[i]) for i in o) + elif isinstance(o, bool): + return False + elif isinstance(o, Printable): + # types known to SymPy + return True + elif any(hasattr(o, hook) for hook in printing_hooks): + # types which add support themselves + return True + elif isinstance(o, (float, int)) and print_builtin: + return True + return False + except RuntimeError: + return False + # This is in case maximum recursion depth is reached. + # Since RecursionError is for versions of Python 3.5+ + # so this is to guard against RecursionError for older versions. + + def _print_latex_png(o): + """ + A function that returns a png rendered by an external latex + distribution, falling back to matplotlib rendering + """ + if _can_print(o): + s = latex(o, mode=latex_mode, **settings) + if latex_mode == 'plain': + s = '$\\displaystyle %s$' % s + try: + return _preview_wrapper(s) + except RuntimeError as e: + debug('preview failed with:', repr(e), + ' Falling back to matplotlib backend') + if latex_mode != 'inline': + s = latex(o, mode='inline', **settings) + return _matplotlib_wrapper(s) + + def _print_latex_svg(o): + """ + A function that returns a svg rendered by an external latex + distribution, no fallback available. + """ + if _can_print(o): + s = latex(o, mode=latex_mode, **settings) + if latex_mode == 'plain': + s = '$\\displaystyle %s$' % s + try: + return _svg_wrapper(s) + except RuntimeError as e: + debug('preview failed with:', repr(e), + ' No fallback available.') + + def _print_latex_matplotlib(o): + """ + A function that returns a png rendered by mathtext + """ + if _can_print(o): + s = latex(o, mode='inline', **settings) + return _matplotlib_wrapper(s) + + def _print_latex_text(o): + """ + A function to generate the latex representation of SymPy expressions. + """ + if _can_print(o): + s = latex(o, mode=latex_mode, **settings) + if latex_mode == 'plain': + return '$\\displaystyle %s$' % s + return s + + # Printable is our own type, so we handle it with methods instead of + # the approach required by builtin types. This allows downstream + # packages to override the methods in their own subclasses of Printable, + # which avoids the effects of gh-16002. + printable_types = [float, tuple, list, set, frozenset, dict, int] + + plaintext_formatter = ip.display_formatter.formatters['text/plain'] + + # Exception to the rule above: IPython has better dispatching rules + # for plaintext printing (xref ipython/ipython#8938), and we can't + # use `_repr_pretty_` without hitting a recursion error in _print_plain. + for cls in printable_types + [Printable]: + plaintext_formatter.for_type(cls, _print_plain) + + svg_formatter = ip.display_formatter.formatters['image/svg+xml'] + if use_latex in ('svg', ): + debug("init_printing: using svg formatter") + for cls in printable_types: + svg_formatter.for_type(cls, _print_latex_svg) + Printable._repr_svg_ = _print_latex_svg + else: + debug("init_printing: not using any svg formatter") + for cls in printable_types: + # Better way to set this, but currently does not work in IPython + #png_formatter.for_type(cls, None) + if cls in svg_formatter.type_printers: + svg_formatter.type_printers.pop(cls) + Printable._repr_svg_ = Printable._repr_disabled + + png_formatter = ip.display_formatter.formatters['image/png'] + if use_latex in (True, 'png'): + debug("init_printing: using png formatter") + for cls in printable_types: + png_formatter.for_type(cls, _print_latex_png) + Printable._repr_png_ = _print_latex_png + elif use_latex == 'matplotlib': + debug("init_printing: using matplotlib formatter") + for cls in printable_types: + png_formatter.for_type(cls, _print_latex_matplotlib) + Printable._repr_png_ = _print_latex_matplotlib + else: + debug("init_printing: not using any png formatter") + for cls in printable_types: + # Better way to set this, but currently does not work in IPython + #png_formatter.for_type(cls, None) + if cls in png_formatter.type_printers: + png_formatter.type_printers.pop(cls) + Printable._repr_png_ = Printable._repr_disabled + + latex_formatter = ip.display_formatter.formatters['text/latex'] + if use_latex in (True, 'mathjax'): + debug("init_printing: using mathjax formatter") + for cls in printable_types: + latex_formatter.for_type(cls, _print_latex_text) + Printable._repr_latex_ = _print_latex_text + else: + debug("init_printing: not using text/latex formatter") + for cls in printable_types: + # Better way to set this, but currently does not work in IPython + #latex_formatter.for_type(cls, None) + if cls in latex_formatter.type_printers: + latex_formatter.type_printers.pop(cls) + Printable._repr_latex_ = Printable._repr_disabled + +def _is_ipython(shell): + """Is a shell instance an IPython shell?""" + # shortcut, so we don't import IPython if we don't have to + from sys import modules + if 'IPython' not in modules: + return False + try: + from IPython.core.interactiveshell import InteractiveShell + except ImportError: + # IPython < 0.11 + try: + from IPython.iplib import InteractiveShell + except ImportError: + # Reaching this points means IPython has changed in a backward-incompatible way + # that we don't know about. Warn? + return False + return isinstance(shell, InteractiveShell) + +# Used by the doctester to override the default for no_global +NO_GLOBAL = False + +def init_printing(pretty_print=True, order=None, use_unicode=None, + use_latex=None, wrap_line=None, num_columns=None, + no_global=False, ip=None, euler=False, forecolor=None, + backcolor='Transparent', fontsize='10pt', + latex_mode='plain', print_builtin=True, + str_printer=None, pretty_printer=None, + latex_printer=None, scale=1.0, **settings): + r""" + Initializes pretty-printer depending on the environment. + + Parameters + ========== + + pretty_print : bool, default=True + If ``True``, use :func:`~.pretty_print` to stringify or the provided pretty + printer; if ``False``, use :func:`~.sstrrepr` to stringify or the provided string + printer. + order : string or None, default='lex' + There are a few different settings for this parameter: + ``'lex'`` (default), which is lexographic order; + ``'grlex'``, which is graded lexographic order; + ``'grevlex'``, which is reversed graded lexographic order; + ``'old'``, which is used for compatibility reasons and for long expressions; + ``None``, which sets it to lex. + use_unicode : bool or None, default=None + If ``True``, use unicode characters; + if ``False``, do not use unicode characters; + if ``None``, make a guess based on the environment. + use_latex : string, bool, or None, default=None + If ``True``, use default LaTeX rendering in GUI interfaces (png and + mathjax); + if ``False``, do not use LaTeX rendering; + if ``None``, make a guess based on the environment; + if ``'png'``, enable LaTeX rendering with an external LaTeX compiler, + falling back to matplotlib if external compilation fails; + if ``'matplotlib'``, enable LaTeX rendering with matplotlib; + if ``'mathjax'``, enable LaTeX text generation, for example MathJax + rendering in IPython notebook or text rendering in LaTeX documents; + if ``'svg'``, enable LaTeX rendering with an external latex compiler, + no fallback + wrap_line : bool + If True, lines will wrap at the end; if False, they will not wrap + but continue as one line. This is only relevant if ``pretty_print`` is + True. + num_columns : int or None, default=None + If ``int``, number of columns before wrapping is set to num_columns; if + ``None``, number of columns before wrapping is set to terminal width. + This is only relevant if ``pretty_print`` is ``True``. + no_global : bool, default=False + If ``True``, the settings become system wide; + if ``False``, use just for this console/session. + ip : An interactive console + This can either be an instance of IPython, + or a class that derives from code.InteractiveConsole. + euler : bool, optional, default=False + Loads the euler package in the LaTeX preamble for handwritten style + fonts (https://www.ctan.org/pkg/euler). + forecolor : string or None, optional, default=None + DVI setting for foreground color. ``None`` means that either ``'Black'``, + ``'White'``, or ``'Gray'`` will be selected based on a guess of the IPython + terminal color setting. See notes. + backcolor : string, optional, default='Transparent' + DVI setting for background color. See notes. + fontsize : string or int, optional, default='10pt' + A font size to pass to the LaTeX documentclass function in the + preamble. Note that the options are limited by the documentclass. + Consider using scale instead. + latex_mode : string, optional, default='plain' + The mode used in the LaTeX printer. Can be one of: + ``{'inline'|'plain'|'equation'|'equation*'}``. + print_builtin : boolean, optional, default=True + If ``True`` then floats and integers will be printed. If ``False`` the + printer will only print SymPy types. + str_printer : function, optional, default=None + A custom string printer function. This should mimic + :func:`~.sstrrepr`. + pretty_printer : function, optional, default=None + A custom pretty printer. This should mimic :func:`~.pretty`. + latex_printer : function, optional, default=None + A custom LaTeX printer. This should mimic :func:`~.latex`. + scale : float, optional, default=1.0 + Scale the LaTeX output when using the ``'png'`` or ``'svg'`` backends. + Useful for high dpi screens. + settings : + Any additional settings for the ``latex`` and ``pretty`` commands can + be used to fine-tune the output. + + Examples + ======== + + >>> from sympy.interactive import init_printing + >>> from sympy import Symbol, sqrt + >>> from sympy.abc import x, y + >>> sqrt(5) + sqrt(5) + >>> init_printing(pretty_print=True) # doctest: +SKIP + >>> sqrt(5) # doctest: +SKIP + ___ + \/ 5 + >>> theta = Symbol('theta') # doctest: +SKIP + >>> init_printing(use_unicode=True) # doctest: +SKIP + >>> theta # doctest: +SKIP + \u03b8 + >>> init_printing(use_unicode=False) # doctest: +SKIP + >>> theta # doctest: +SKIP + theta + >>> init_printing(order='lex') # doctest: +SKIP + >>> str(y + x + y**2 + x**2) # doctest: +SKIP + x**2 + x + y**2 + y + >>> init_printing(order='grlex') # doctest: +SKIP + >>> str(y + x + y**2 + x**2) # doctest: +SKIP + x**2 + x + y**2 + y + >>> init_printing(order='grevlex') # doctest: +SKIP + >>> str(y * x**2 + x * y**2) # doctest: +SKIP + x**2*y + x*y**2 + >>> init_printing(order='old') # doctest: +SKIP + >>> str(x**2 + y**2 + x + y) # doctest: +SKIP + x**2 + x + y**2 + y + >>> init_printing(num_columns=10) # doctest: +SKIP + >>> x**2 + x + y**2 + y # doctest: +SKIP + x + y + + x**2 + y**2 + + Notes + ===== + + The foreground and background colors can be selected when using ``'png'`` or + ``'svg'`` LaTeX rendering. Note that before the ``init_printing`` command is + executed, the LaTeX rendering is handled by the IPython console and not SymPy. + + The colors can be selected among the 68 standard colors known to ``dvips``, + for a list see [1]_. In addition, the background color can be + set to ``'Transparent'`` (which is the default value). + + When using the ``'Auto'`` foreground color, the guess is based on the + ``colors`` variable in the IPython console, see [2]_. Hence, if + that variable is set correctly in your IPython console, there is a high + chance that the output will be readable, although manual settings may be + needed. + + + References + ========== + + .. [1] https://en.wikibooks.org/wiki/LaTeX/Colors#The_68_standard_colors_known_to_dvips + + .. [2] https://ipython.readthedocs.io/en/stable/config/details.html#terminal-colors + + See Also + ======== + + sympy.printing.latex + sympy.printing.pretty + + """ + import sys + from sympy.printing.printer import Printer + + if pretty_print: + if pretty_printer is not None: + stringify_func = pretty_printer + else: + from sympy.printing import pretty as stringify_func + else: + if str_printer is not None: + stringify_func = str_printer + else: + from sympy.printing import sstrrepr as stringify_func + + # Even if ip is not passed, double check that not in IPython shell + in_ipython = False + if ip is None: + try: + ip = get_ipython() + except NameError: + pass + else: + in_ipython = (ip is not None) + + if ip and not in_ipython: + in_ipython = _is_ipython(ip) + + if in_ipython and pretty_print: + try: + from IPython.terminal.interactiveshell import TerminalInteractiveShell + from code import InteractiveConsole + except ImportError: + pass + else: + # This will be True if we are in the qtconsole or notebook + if not isinstance(ip, (InteractiveConsole, TerminalInteractiveShell)) \ + and 'ipython-console' not in ''.join(sys.argv): + if use_unicode is None: + debug("init_printing: Setting use_unicode to True") + use_unicode = True + if use_latex is None: + debug("init_printing: Setting use_latex to True") + use_latex = True + + if not NO_GLOBAL and not no_global: + Printer.set_global_settings(order=order, use_unicode=use_unicode, + wrap_line=wrap_line, num_columns=num_columns) + else: + _stringify_func = stringify_func + + if pretty_print: + stringify_func = lambda expr, **settings: \ + _stringify_func(expr, order=order, + use_unicode=use_unicode, + wrap_line=wrap_line, + num_columns=num_columns, + **settings) + else: + stringify_func = \ + lambda expr, **settings: _stringify_func( + expr, order=order, **settings) + + if in_ipython: + mode_in_settings = settings.pop("mode", None) + if mode_in_settings: + debug("init_printing: Mode is not able to be set due to internals" + "of IPython printing") + _init_ipython_printing(ip, stringify_func, use_latex, euler, + forecolor, backcolor, fontsize, latex_mode, + print_builtin, latex_printer, scale, + **settings) + else: + _init_python_printing(stringify_func, **settings) diff --git a/.venv/lib/python3.13/site-packages/sympy/interactive/session.py b/.venv/lib/python3.13/site-packages/sympy/interactive/session.py new file mode 100644 index 0000000000000000000000000000000000000000..348b0938d69e5e7ffa9510f7d9ac759eb6683b8f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/interactive/session.py @@ -0,0 +1,463 @@ +"""Tools for setting up interactive sessions. """ + +from sympy.external.gmpy import GROUND_TYPES +from sympy.external.importtools import version_tuple + +from sympy.interactive.printing import init_printing + +from sympy.utilities.misc import ARCH + +preexec_source = """\ +from sympy import * +x, y, z, t = symbols('x y z t') +k, m, n = symbols('k m n', integer=True) +f, g, h = symbols('f g h', cls=Function) +init_printing() +""" + +verbose_message = """\ +These commands were executed: +%(source)s +Documentation can be found at https://docs.sympy.org/%(version)s +""" + +no_ipython = """\ +Could not locate IPython. Having IPython installed is greatly recommended. +See http://ipython.scipy.org for more details. If you use Debian/Ubuntu, +just install the 'ipython' package and start isympy again. +""" + + +def _make_message(ipython=True, quiet=False, source=None): + """Create a banner for an interactive session. """ + from sympy import __version__ as sympy_version + from sympy import SYMPY_DEBUG + + import sys + import os + + if quiet: + return "" + + python_version = "%d.%d.%d" % sys.version_info[:3] + + if ipython: + shell_name = "IPython" + else: + shell_name = "Python" + + info = ['ground types: %s' % GROUND_TYPES] + + cache = os.getenv('SYMPY_USE_CACHE') + + if cache is not None and cache.lower() == 'no': + info.append('cache: off') + + if SYMPY_DEBUG: + info.append('debugging: on') + + args = shell_name, sympy_version, python_version, ARCH, ', '.join(info) + message = "%s console for SymPy %s (Python %s-%s) (%s)\n" % args + + if source is None: + source = preexec_source + + _source = "" + + for line in source.split('\n')[:-1]: + if not line: + _source += '\n' + else: + _source += '>>> ' + line + '\n' + + doc_version = sympy_version + if 'dev' in doc_version: + doc_version = "dev" + else: + doc_version = "%s/" % doc_version + + message += '\n' + verbose_message % {'source': _source, + 'version': doc_version} + + return message + + +def int_to_Integer(s): + """ + Wrap integer literals with Integer. + + This is based on the decistmt example from + https://docs.python.org/3/library/tokenize.html. + + Only integer literals are converted. Float literals are left alone. + + Examples + ======== + + >>> from sympy import Integer # noqa: F401 + >>> from sympy.interactive.session import int_to_Integer + >>> s = '1.2 + 1/2 - 0x12 + a1' + >>> int_to_Integer(s) + '1.2 +Integer (1 )/Integer (2 )-Integer (0x12 )+a1 ' + >>> s = 'print (1/2)' + >>> int_to_Integer(s) + 'print (Integer (1 )/Integer (2 ))' + >>> exec(s) + 0.5 + >>> exec(int_to_Integer(s)) + 1/2 + """ + from tokenize import generate_tokens, untokenize, NUMBER, NAME, OP + from io import StringIO + + def _is_int(num): + """ + Returns true if string value num (with token NUMBER) represents an integer. + """ + # XXX: Is there something in the standard library that will do this? + if '.' in num or 'j' in num.lower() or 'e' in num.lower(): + return False + return True + + result = [] + g = generate_tokens(StringIO(s).readline) # tokenize the string + for toknum, tokval, _, _, _ in g: + if toknum == NUMBER and _is_int(tokval): # replace NUMBER tokens + result.extend([ + (NAME, 'Integer'), + (OP, '('), + (NUMBER, tokval), + (OP, ')') + ]) + else: + result.append((toknum, tokval)) + return untokenize(result) + + +def enable_automatic_int_sympification(shell): + """ + Allow IPython to automatically convert integer literals to Integer. + """ + import ast + old_run_cell = shell.run_cell + + def my_run_cell(cell, *args, **kwargs): + try: + # Check the cell for syntax errors. This way, the syntax error + # will show the original input, not the transformed input. The + # downside here is that IPython magic like %timeit will not work + # with transformed input (but on the other hand, IPython magic + # that doesn't expect transformed input will continue to work). + ast.parse(cell) + except SyntaxError: + pass + else: + cell = int_to_Integer(cell) + return old_run_cell(cell, *args, **kwargs) + + shell.run_cell = my_run_cell + + +def enable_automatic_symbols(shell): + """Allow IPython to automatically create symbols (``isympy -a``). """ + # XXX: This should perhaps use tokenize, like int_to_Integer() above. + # This would avoid re-executing the code, which can lead to subtle + # issues. For example: + # + # In [1]: a = 1 + # + # In [2]: for i in range(10): + # ...: a += 1 + # ...: + # + # In [3]: a + # Out[3]: 11 + # + # In [4]: a = 1 + # + # In [5]: for i in range(10): + # ...: a += 1 + # ...: print b + # ...: + # b + # b + # b + # b + # b + # b + # b + # b + # b + # b + # + # In [6]: a + # Out[6]: 12 + # + # Note how the for loop is executed again because `b` was not defined, but `a` + # was already incremented once, so the result is that it is incremented + # multiple times. + + import re + re_nameerror = re.compile( + "name '(?P[A-Za-z_][A-Za-z0-9_]*)' is not defined") + + def _handler(self, etype, value, tb, tb_offset=None): + """Handle :exc:`NameError` exception and allow injection of missing symbols. """ + if etype is NameError and tb.tb_next and not tb.tb_next.tb_next: + match = re_nameerror.match(str(value)) + + if match is not None: + # XXX: Make sure Symbol is in scope. Otherwise you'll get infinite recursion. + self.run_cell("%(symbol)s = Symbol('%(symbol)s')" % + {'symbol': match.group("symbol")}, store_history=False) + + try: + code = self.user_ns['In'][-1] + except (KeyError, IndexError): + pass + else: + self.run_cell(code, store_history=False) + return None + finally: + self.run_cell("del %s" % match.group("symbol"), + store_history=False) + + stb = self.InteractiveTB.structured_traceback( + etype, value, tb, tb_offset=tb_offset) + self._showtraceback(etype, value, stb) + + shell.set_custom_exc((NameError,), _handler) + + +def init_ipython_session(shell=None, argv=[], auto_symbols=False, auto_int_to_Integer=False): + """Construct new IPython session. """ + import IPython + + if version_tuple(IPython.__version__) >= version_tuple('0.11'): + if not shell: + # use an app to parse the command line, and init config + # IPython 1.0 deprecates the frontend module, so we import directly + # from the terminal module to prevent a deprecation message from being + # shown. + if version_tuple(IPython.__version__) >= version_tuple('1.0'): + from IPython.terminal import ipapp + else: + from IPython.frontend.terminal import ipapp + app = ipapp.TerminalIPythonApp() + + # don't draw IPython banner during initialization: + app.display_banner = False + app.initialize(argv) + + shell = app.shell + + if auto_symbols: + enable_automatic_symbols(shell) + if auto_int_to_Integer: + enable_automatic_int_sympification(shell) + + return shell + else: + from IPython.Shell import make_IPython + return make_IPython(argv) + + +def init_python_session(): + """Construct new Python session. """ + from code import InteractiveConsole + + class SymPyConsole(InteractiveConsole): + """An interactive console with readline support. """ + + def __init__(self): + ns_locals = {} + InteractiveConsole.__init__(self, locals=ns_locals) + try: + import rlcompleter + import readline + except ImportError: + pass + else: + import os + import atexit + + readline.set_completer(rlcompleter.Completer(ns_locals).complete) + readline.parse_and_bind('tab: complete') + + if hasattr(readline, 'read_history_file'): + history = os.path.expanduser('~/.sympy-history') + + try: + readline.read_history_file(history) + except OSError: + pass + + atexit.register(readline.write_history_file, history) + + return SymPyConsole() + + +def init_session(ipython=None, pretty_print=True, order=None, + use_unicode=None, use_latex=None, quiet=False, auto_symbols=False, + auto_int_to_Integer=False, str_printer=None, pretty_printer=None, + latex_printer=None, argv=[]): + """ + Initialize an embedded IPython or Python session. The IPython session is + initiated with the --pylab option, without the numpy imports, so that + matplotlib plotting can be interactive. + + Parameters + ========== + + pretty_print: boolean + If True, use pretty_print to stringify; + if False, use sstrrepr to stringify. + order: string or None + There are a few different settings for this parameter: + lex (default), which is lexographic order; + grlex, which is graded lexographic order; + grevlex, which is reversed graded lexographic order; + old, which is used for compatibility reasons and for long expressions; + None, which sets it to lex. + use_unicode: boolean or None + If True, use unicode characters; + if False, do not use unicode characters. + use_latex: boolean or None + If True, use latex rendering if IPython GUI's; + if False, do not use latex rendering. + quiet: boolean + If True, init_session will not print messages regarding its status; + if False, init_session will print messages regarding its status. + auto_symbols: boolean + If True, IPython will automatically create symbols for you. + If False, it will not. + The default is False. + auto_int_to_Integer: boolean + If True, IPython will automatically wrap int literals with Integer, so + that things like 1/2 give Rational(1, 2). + If False, it will not. + The default is False. + ipython: boolean or None + If True, printing will initialize for an IPython console; + if False, printing will initialize for a normal console; + The default is None, which automatically determines whether we are in + an ipython instance or not. + str_printer: function, optional, default=None + A custom string printer function. This should mimic + sympy.printing.sstrrepr(). + pretty_printer: function, optional, default=None + A custom pretty printer. This should mimic sympy.printing.pretty(). + latex_printer: function, optional, default=None + A custom LaTeX printer. This should mimic sympy.printing.latex() + This should mimic sympy.printing.latex(). + argv: list of arguments for IPython + See sympy.bin.isympy for options that can be used to initialize IPython. + + See Also + ======== + + sympy.interactive.printing.init_printing: for examples and the rest of the parameters. + + + Examples + ======== + + >>> from sympy import init_session, Symbol, sin, sqrt + >>> sin(x) #doctest: +SKIP + NameError: name 'x' is not defined + >>> init_session() #doctest: +SKIP + >>> sin(x) #doctest: +SKIP + sin(x) + >>> sqrt(5) #doctest: +SKIP + ___ + \\/ 5 + >>> init_session(pretty_print=False) #doctest: +SKIP + >>> sqrt(5) #doctest: +SKIP + sqrt(5) + >>> y + x + y**2 + x**2 #doctest: +SKIP + x**2 + x + y**2 + y + >>> init_session(order='grlex') #doctest: +SKIP + >>> y + x + y**2 + x**2 #doctest: +SKIP + x**2 + y**2 + x + y + >>> init_session(order='grevlex') #doctest: +SKIP + >>> y * x**2 + x * y**2 #doctest: +SKIP + x**2*y + x*y**2 + >>> init_session(order='old') #doctest: +SKIP + >>> x**2 + y**2 + x + y #doctest: +SKIP + x + y + x**2 + y**2 + >>> theta = Symbol('theta') #doctest: +SKIP + >>> theta #doctest: +SKIP + theta + >>> init_session(use_unicode=True) #doctest: +SKIP + >>> theta # doctest: +SKIP + \u03b8 + """ + import sys + + in_ipython = False + + if ipython is not False: + try: + import IPython + except ImportError: + if ipython is True: + raise RuntimeError("IPython is not available on this system") + ip = None + else: + try: + from IPython import get_ipython + ip = get_ipython() + except ImportError: + ip = None + in_ipython = bool(ip) + if ipython is None: + ipython = in_ipython + + if ipython is False: + ip = init_python_session() + mainloop = ip.interact + else: + ip = init_ipython_session(ip, argv=argv, auto_symbols=auto_symbols, + auto_int_to_Integer=auto_int_to_Integer) + + if version_tuple(IPython.__version__) >= version_tuple('0.11'): + # runsource is gone, use run_cell instead, which doesn't + # take a symbol arg. The second arg is `store_history`, + # and False means don't add the line to IPython's history. + ip.runsource = lambda src, symbol='exec': ip.run_cell(src, False) + + # Enable interactive plotting using pylab. + try: + ip.enable_pylab(import_all=False) + except Exception: + # Causes an import error if matplotlib is not installed. + # Causes other errors (depending on the backend) if there + # is no display, or if there is some problem in the + # backend, so we have a bare "except Exception" here + pass + if not in_ipython: + mainloop = ip.mainloop + + if auto_symbols and (not ipython or version_tuple(IPython.__version__) < version_tuple('0.11')): + raise RuntimeError("automatic construction of symbols is possible only in IPython 0.11 or above") + if auto_int_to_Integer and (not ipython or version_tuple(IPython.__version__) < version_tuple('0.11')): + raise RuntimeError("automatic int to Integer transformation is possible only in IPython 0.11 or above") + + _preexec_source = preexec_source + + ip.runsource(_preexec_source, symbol='exec') + init_printing(pretty_print=pretty_print, order=order, + use_unicode=use_unicode, use_latex=use_latex, ip=ip, + str_printer=str_printer, pretty_printer=pretty_printer, + latex_printer=latex_printer) + + message = _make_message(ipython, quiet, _preexec_source) + + if not in_ipython: + print(message) + mainloop() + sys.exit('Exiting ...') + else: + print(message) + import atexit + atexit.register(lambda: print("Exiting ...\n")) diff --git a/.venv/lib/python3.13/site-packages/sympy/interactive/traversal.py b/.venv/lib/python3.13/site-packages/sympy/interactive/traversal.py new file mode 100644 index 0000000000000000000000000000000000000000..1315ec4ef7868b666bb6b978b3d8b20442d100b0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/interactive/traversal.py @@ -0,0 +1,95 @@ +from sympy.core.basic import Basic +from sympy.printing import pprint + +import random + +def interactive_traversal(expr): + """Traverse a tree asking a user which branch to choose. """ + + RED, BRED = '\033[0;31m', '\033[1;31m' + GREEN, BGREEN = '\033[0;32m', '\033[1;32m' + YELLOW, BYELLOW = '\033[0;33m', '\033[1;33m' # noqa + BLUE, BBLUE = '\033[0;34m', '\033[1;34m' # noqa + MAGENTA, BMAGENTA = '\033[0;35m', '\033[1;35m'# noqa + CYAN, BCYAN = '\033[0;36m', '\033[1;36m' # noqa + END = '\033[0m' + + def cprint(*args): + print("".join(map(str, args)) + END) + + def _interactive_traversal(expr, stage): + if stage > 0: + print() + + cprint("Current expression (stage ", BYELLOW, stage, END, "):") + print(BCYAN) + pprint(expr) + print(END) + + if isinstance(expr, Basic): + if expr.is_Add: + args = expr.as_ordered_terms() + elif expr.is_Mul: + args = expr.as_ordered_factors() + else: + args = expr.args + elif hasattr(expr, "__iter__"): + args = list(expr) + else: + return expr + + n_args = len(args) + + if not n_args: + return expr + + for i, arg in enumerate(args): + cprint(GREEN, "[", BGREEN, i, GREEN, "] ", BLUE, type(arg), END) + pprint(arg) + print() + + if n_args == 1: + choices = '0' + else: + choices = '0-%d' % (n_args - 1) + + try: + choice = input("Your choice [%s,f,l,r,d,?]: " % choices) + except EOFError: + result = expr + print() + else: + if choice == '?': + cprint(RED, "%s - select subexpression with the given index" % + choices) + cprint(RED, "f - select the first subexpression") + cprint(RED, "l - select the last subexpression") + cprint(RED, "r - select a random subexpression") + cprint(RED, "d - done\n") + + result = _interactive_traversal(expr, stage) + elif choice in ('d', ''): + result = expr + elif choice == 'f': + result = _interactive_traversal(args[0], stage + 1) + elif choice == 'l': + result = _interactive_traversal(args[-1], stage + 1) + elif choice == 'r': + result = _interactive_traversal(random.choice(args), stage + 1) + else: + try: + choice = int(choice) + except ValueError: + cprint(BRED, + "Choice must be a number in %s range\n" % choices) + result = _interactive_traversal(expr, stage) + else: + if choice < 0 or choice >= n_args: + cprint(BRED, "Choice must be in %s range\n" % choices) + result = _interactive_traversal(expr, stage) + else: + result = _interactive_traversal(args[choice], stage + 1) + + return result + + return _interactive_traversal(expr, 0) diff --git a/.venv/lib/python3.13/site-packages/sympy/liealgebras/__init__.py b/.venv/lib/python3.13/site-packages/sympy/liealgebras/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d023d86f2c6f0c64d7ac460c50eedc355e78b21f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/liealgebras/__init__.py @@ -0,0 +1,3 @@ +from sympy.liealgebras.cartan_type import CartanType + +__all__ = ['CartanType'] diff --git a/.venv/lib/python3.13/site-packages/sympy/liealgebras/cartan_matrix.py b/.venv/lib/python3.13/site-packages/sympy/liealgebras/cartan_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..2d29b37bc9a1a26790ee88b5902951afe4fc4560 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/liealgebras/cartan_matrix.py @@ -0,0 +1,25 @@ +from .cartan_type import CartanType + +def CartanMatrix(ct): + """Access the Cartan matrix of a specific Lie algebra + + Examples + ======== + + >>> from sympy.liealgebras.cartan_matrix import CartanMatrix + >>> CartanMatrix("A2") + Matrix([ + [ 2, -1], + [-1, 2]]) + + >>> CartanMatrix(['C', 3]) + Matrix([ + [ 2, -1, 0], + [-1, 2, -1], + [ 0, -2, 2]]) + + This method works by returning the Cartan matrix + which corresponds to Cartan type t. + """ + + return CartanType(ct).cartan_matrix() diff --git a/.venv/lib/python3.13/site-packages/sympy/liealgebras/cartan_type.py b/.venv/lib/python3.13/site-packages/sympy/liealgebras/cartan_type.py new file mode 100644 index 0000000000000000000000000000000000000000..16bb152469238ea912a30c2d0f8210d6f729bdb1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/liealgebras/cartan_type.py @@ -0,0 +1,73 @@ +from sympy.core import Atom, Basic + + +class CartanType_generator(): + """ + Constructor for actually creating things + """ + def __call__(self, *args): + c = args[0] + if isinstance(c, list): + letter, n = c[0], int(c[1]) + elif isinstance(c, str): + letter, n = c[0], int(c[1:]) + else: + raise TypeError("Argument must be a string (e.g. 'A3') or a list (e.g. ['A', 3])") + + if n < 0: + raise ValueError("Lie algebra rank cannot be negative") + if letter == "A": + from . import type_a + return type_a.TypeA(n) + if letter == "B": + from . import type_b + return type_b.TypeB(n) + + if letter == "C": + from . import type_c + return type_c.TypeC(n) + + if letter == "D": + from . import type_d + return type_d.TypeD(n) + + if letter == "E": + if n >= 6 and n <= 8: + from . import type_e + return type_e.TypeE(n) + + if letter == "F": + if n == 4: + from . import type_f + return type_f.TypeF(n) + + if letter == "G": + if n == 2: + from . import type_g + return type_g.TypeG(n) + +CartanType = CartanType_generator() + + +class Standard_Cartan(Atom): + """ + Concrete base class for Cartan types such as A4, etc + """ + + def __new__(cls, series, n): + obj = Basic.__new__(cls) + obj.n = n + obj.series = series + return obj + + def rank(self): + """ + Returns the rank of the Lie algebra + """ + return self.n + + def series(self): + """ + Returns the type of the Lie algebra + """ + return self.series diff --git a/.venv/lib/python3.13/site-packages/sympy/liealgebras/dynkin_diagram.py b/.venv/lib/python3.13/site-packages/sympy/liealgebras/dynkin_diagram.py new file mode 100644 index 0000000000000000000000000000000000000000..cc9e2dac4d54490b803eeaf9637cb9b66b01f058 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/liealgebras/dynkin_diagram.py @@ -0,0 +1,24 @@ +from .cartan_type import CartanType + + +def DynkinDiagram(t): + """Display the Dynkin diagram of a given Lie algebra + + Works by generating the CartanType for the input, t, and then returning the + Dynkin diagram method from the individual classes. + + Examples + ======== + + >>> from sympy.liealgebras.dynkin_diagram import DynkinDiagram + >>> print(DynkinDiagram("A3")) + 0---0---0 + 1 2 3 + + >>> print(DynkinDiagram("B4")) + 0---0---0=>=0 + 1 2 3 4 + + """ + + return CartanType(t).dynkin_diagram() diff --git a/.venv/lib/python3.13/site-packages/sympy/liealgebras/root_system.py b/.venv/lib/python3.13/site-packages/sympy/liealgebras/root_system.py new file mode 100644 index 0000000000000000000000000000000000000000..36eb24605e78bbdc669736910d89be5606df1389 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/liealgebras/root_system.py @@ -0,0 +1,196 @@ +from .cartan_type import CartanType +from sympy.core.basic import Atom + +class RootSystem(Atom): + """Represent the root system of a simple Lie algebra + + Every simple Lie algebra has a unique root system. To find the root + system, we first consider the Cartan subalgebra of g, which is the maximal + abelian subalgebra, and consider the adjoint action of g on this + subalgebra. There is a root system associated with this action. Now, a + root system over a vector space V is a set of finite vectors Phi (called + roots), which satisfy: + + 1. The roots span V + 2. The only scalar multiples of x in Phi are x and -x + 3. For every x in Phi, the set Phi is closed under reflection + through the hyperplane perpendicular to x. + 4. If x and y are roots in Phi, then the projection of y onto + the line through x is a half-integral multiple of x. + + Now, there is a subset of Phi, which we will call Delta, such that: + 1. Delta is a basis of V + 2. Each root x in Phi can be written x = sum k_y y for y in Delta + + The elements of Delta are called the simple roots. + Therefore, we see that the simple roots span the root space of a given + simple Lie algebra. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Root_system + .. [2] Lie Algebras and Representation Theory - Humphreys + + """ + + def __new__(cls, cartantype): + """Create a new RootSystem object + + This method assigns an attribute called cartan_type to each instance of + a RootSystem object. When an instance of RootSystem is called, it + needs an argument, which should be an instance of a simple Lie algebra. + We then take the CartanType of this argument and set it as the + cartan_type attribute of the RootSystem instance. + + """ + obj = Atom.__new__(cls) + obj.cartan_type = CartanType(cartantype) + return obj + + def simple_roots(self): + """Generate the simple roots of the Lie algebra + + The rank of the Lie algebra determines the number of simple roots that + it has. This method obtains the rank of the Lie algebra, and then uses + the simple_root method from the Lie algebra classes to generate all the + simple roots. + + Examples + ======== + + >>> from sympy.liealgebras.root_system import RootSystem + >>> c = RootSystem("A3") + >>> roots = c.simple_roots() + >>> roots + {1: [1, -1, 0, 0], 2: [0, 1, -1, 0], 3: [0, 0, 1, -1]} + + """ + n = self.cartan_type.rank() + roots = {i: self.cartan_type.simple_root(i) for i in range(1, n+1)} + return roots + + + def all_roots(self): + """Generate all the roots of a given root system + + The result is a dictionary where the keys are integer numbers. It + generates the roots by getting the dictionary of all positive roots + from the bases classes, and then taking each root, and multiplying it + by -1 and adding it to the dictionary. In this way all the negative + roots are generated. + + """ + alpha = self.cartan_type.positive_roots() + keys = list(alpha.keys()) + k = max(keys) + for val in keys: + k += 1 + root = alpha[val] + newroot = [-x for x in root] + alpha[k] = newroot + return alpha + + def root_space(self): + """Return the span of the simple roots + + The root space is the vector space spanned by the simple roots, i.e. it + is a vector space with a distinguished basis, the simple roots. This + method returns a string that represents the root space as the span of + the simple roots, alpha[1],...., alpha[n]. + + Examples + ======== + + >>> from sympy.liealgebras.root_system import RootSystem + >>> c = RootSystem("A3") + >>> c.root_space() + 'alpha[1] + alpha[2] + alpha[3]' + + """ + n = self.cartan_type.rank() + rs = " + ".join("alpha["+str(i) +"]" for i in range(1, n+1)) + return rs + + def add_simple_roots(self, root1, root2): + """Add two simple roots together + + The function takes as input two integers, root1 and root2. It then + uses these integers as keys in the dictionary of simple roots, and gets + the corresponding simple roots, and then adds them together. + + Examples + ======== + + >>> from sympy.liealgebras.root_system import RootSystem + >>> c = RootSystem("A3") + >>> newroot = c.add_simple_roots(1, 2) + >>> newroot + [1, 0, -1, 0] + + """ + + alpha = self.simple_roots() + if root1 > len(alpha) or root2 > len(alpha): + raise ValueError("You've used a root that doesn't exist!") + a1 = alpha[root1] + a2 = alpha[root2] + newroot = [_a1 + _a2 for _a1, _a2 in zip(a1, a2)] + return newroot + + def add_as_roots(self, root1, root2): + """Add two roots together if and only if their sum is also a root + + It takes as input two vectors which should be roots. It then computes + their sum and checks if it is in the list of all possible roots. If it + is, it returns the sum. Otherwise it returns a string saying that the + sum is not a root. + + Examples + ======== + + >>> from sympy.liealgebras.root_system import RootSystem + >>> c = RootSystem("A3") + >>> c.add_as_roots([1, 0, -1, 0], [0, 0, 1, -1]) + [1, 0, 0, -1] + >>> c.add_as_roots([1, -1, 0, 0], [0, 0, -1, 1]) + 'The sum of these two roots is not a root' + + """ + alpha = self.all_roots() + newroot = [r1 + r2 for r1, r2 in zip(root1, root2)] + if newroot in alpha.values(): + return newroot + else: + return "The sum of these two roots is not a root" + + + def cartan_matrix(self): + """Cartan matrix of Lie algebra associated with this root system + + Examples + ======== + + >>> from sympy.liealgebras.root_system import RootSystem + >>> c = RootSystem("A3") + >>> c.cartan_matrix() + Matrix([ + [ 2, -1, 0], + [-1, 2, -1], + [ 0, -1, 2]]) + """ + return self.cartan_type.cartan_matrix() + + def dynkin_diagram(self): + """Dynkin diagram of the Lie algebra associated with this root system + + Examples + ======== + + >>> from sympy.liealgebras.root_system import RootSystem + >>> c = RootSystem("A3") + >>> print(c.dynkin_diagram()) + 0---0---0 + 1 2 3 + """ + return self.cartan_type.dynkin_diagram() diff --git a/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_a.py b/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_a.py new file mode 100644 index 0000000000000000000000000000000000000000..96dc615366ae20d668d651620ac088f15751c50e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_a.py @@ -0,0 +1,164 @@ +from sympy.liealgebras.cartan_type import Standard_Cartan +from sympy.core.backend import eye + + +class TypeA(Standard_Cartan): + """ + This class contains the information about + the A series of simple Lie algebras. + ==== + """ + + def __new__(cls, n): + if n < 1: + raise ValueError("n cannot be less than 1") + return Standard_Cartan.__new__(cls, "A", n) + + + def dimension(self): + """Dimension of the vector space V underlying the Lie algebra + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A4") + >>> c.dimension() + 5 + """ + return self.n+1 + + + def basic_root(self, i, j): + """ + This is a method just to generate roots + with a 1 iin the ith position and a -1 + in the jth position. + + """ + + n = self.n + root = [0]*(n+1) + root[i] = 1 + root[j] = -1 + return root + + def simple_root(self, i): + """ + Every lie algebra has a unique root system. + Given a root system Q, there is a subset of the + roots such that an element of Q is called a + simple root if it cannot be written as the sum + of two elements in Q. If we let D denote the + set of simple roots, then it is clear that every + element of Q can be written as a linear combination + of elements of D with all coefficients non-negative. + + In A_n the ith simple root is the root which has a 1 + in the ith position, a -1 in the (i+1)th position, + and zeroes elsewhere. + + This method returns the ith simple root for the A series. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A4") + >>> c.simple_root(1) + [1, -1, 0, 0, 0] + + """ + + return self.basic_root(i-1, i) + + def positive_roots(self): + """ + This method generates all the positive roots of + A_n. This is half of all of the roots of A_n; + by multiplying all the positive roots by -1 we + get the negative roots. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A3") + >>> c.positive_roots() + {1: [1, -1, 0, 0], 2: [1, 0, -1, 0], 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], + 5: [0, 1, 0, -1], 6: [0, 0, 1, -1]} + """ + + n = self.n + posroots = {} + k = 0 + for i in range(0, n): + for j in range(i+1, n+1): + k += 1 + posroots[k] = self.basic_root(i, j) + return posroots + + def highest_root(self): + """ + Returns the highest weight root for A_n + """ + + return self.basic_root(0, self.n) + + def roots(self): + """ + Returns the total number of roots for A_n + """ + n = self.n + return n*(n+1) + + def cartan_matrix(self): + """ + Returns the Cartan matrix for A_n. + The Cartan matrix matrix for a Lie algebra is + generated by assigning an ordering to the simple + roots, (alpha[1], ...., alpha[l]). Then the ijth + entry of the Cartan matrix is (). + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType('A4') + >>> c.cartan_matrix() + Matrix([ + [ 2, -1, 0, 0], + [-1, 2, -1, 0], + [ 0, -1, 2, -1], + [ 0, 0, -1, 2]]) + + """ + + n = self.n + m = 2 * eye(n) + for i in range(1, n - 1): + m[i, i+1] = -1 + m[i, i-1] = -1 + m[0,1] = -1 + m[n-1, n-2] = -1 + return m + + def basis(self): + """ + Returns the number of independent generators of A_n + """ + n = self.n + return n**2 - 1 + + def lie_algebra(self): + """ + Returns the Lie algebra associated with A_n + """ + n = self.n + return "su(" + str(n + 1) + ")" + + def dynkin_diagram(self): + n = self.n + diag = "---".join("0" for i in range(1, n+1)) + "\n" + diag += " ".join(str(i) for i in range(1, n+1)) + return diag diff --git a/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_b.py b/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_b.py new file mode 100644 index 0000000000000000000000000000000000000000..c6ee85502261f4702769067c64021521a2bc1725 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_b.py @@ -0,0 +1,170 @@ +from .cartan_type import Standard_Cartan +from sympy.core.backend import eye + +class TypeB(Standard_Cartan): + + def __new__(cls, n): + if n < 2: + raise ValueError("n cannot be less than 2") + return Standard_Cartan.__new__(cls, "B", n) + + def dimension(self): + """Dimension of the vector space V underlying the Lie algebra + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("B3") + >>> c.dimension() + 3 + """ + + return self.n + + def basic_root(self, i, j): + """ + This is a method just to generate roots + with a 1 iin the ith position and a -1 + in the jth position. + + """ + root = [0]*self.n + root[i] = 1 + root[j] = -1 + return root + + def simple_root(self, i): + """ + Every lie algebra has a unique root system. + Given a root system Q, there is a subset of the + roots such that an element of Q is called a + simple root if it cannot be written as the sum + of two elements in Q. If we let D denote the + set of simple roots, then it is clear that every + element of Q can be written as a linear combination + of elements of D with all coefficients non-negative. + + In B_n the first n-1 simple roots are the same as the + roots in A_(n-1) (a 1 in the ith position, a -1 in + the (i+1)th position, and zeroes elsewhere). The n-th + simple root is the root with a 1 in the nth position + and zeroes elsewhere. + + This method returns the ith simple root for the B series. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("B3") + >>> c.simple_root(2) + [0, 1, -1] + + """ + n = self.n + if i < n: + return self.basic_root(i-1, i) + else: + root = [0]*self.n + root[n-1] = 1 + return root + + def positive_roots(self): + """ + This method generates all the positive roots of + A_n. This is half of all of the roots of B_n; + by multiplying all the positive roots by -1 we + get the negative roots. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A3") + >>> c.positive_roots() + {1: [1, -1, 0, 0], 2: [1, 0, -1, 0], 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], + 5: [0, 1, 0, -1], 6: [0, 0, 1, -1]} + """ + + n = self.n + posroots = {} + k = 0 + for i in range(0, n-1): + for j in range(i+1, n): + k += 1 + posroots[k] = self.basic_root(i, j) + k += 1 + root = self.basic_root(i, j) + root[j] = 1 + posroots[k] = root + + for i in range(0, n): + k += 1 + root = [0]*n + root[i] = 1 + posroots[k] = root + + return posroots + + def roots(self): + """ + Returns the total number of roots for B_n" + """ + + n = self.n + return 2*(n**2) + + def cartan_matrix(self): + """ + Returns the Cartan matrix for B_n. + The Cartan matrix matrix for a Lie algebra is + generated by assigning an ordering to the simple + roots, (alpha[1], ...., alpha[l]). Then the ijth + entry of the Cartan matrix is (). + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType('B4') + >>> c.cartan_matrix() + Matrix([ + [ 2, -1, 0, 0], + [-1, 2, -1, 0], + [ 0, -1, 2, -2], + [ 0, 0, -1, 2]]) + + """ + + n = self.n + m = 2* eye(n) + for i in range(1, n - 1): + m[i, i+1] = -1 + m[i, i-1] = -1 + m[0, 1] = -1 + m[n-2, n-1] = -2 + m[n-1, n-2] = -1 + return m + + def basis(self): + """ + Returns the number of independent generators of B_n + """ + + n = self.n + return (n**2 - n)/2 + + def lie_algebra(self): + """ + Returns the Lie algebra associated with B_n + """ + + n = self.n + return "so(" + str(2*n) + ")" + + def dynkin_diagram(self): + n = self.n + diag = "---".join("0" for i in range(1, n)) + "=>=0\n" + diag += " ".join(str(i) for i in range(1, n+1)) + return diag diff --git a/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_c.py b/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_c.py new file mode 100644 index 0000000000000000000000000000000000000000..615bb900b5ba9613fd02e43f476d34eef0d5d35c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_c.py @@ -0,0 +1,169 @@ +from .cartan_type import Standard_Cartan +from sympy.core.backend import eye + +class TypeC(Standard_Cartan): + + def __new__(cls, n): + if n < 3: + raise ValueError("n cannot be less than 3") + return Standard_Cartan.__new__(cls, "C", n) + + + def dimension(self): + """Dimension of the vector space V underlying the Lie algebra + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("C3") + >>> c.dimension() + 3 + """ + n = self.n + return n + + def basic_root(self, i, j): + """Generate roots with 1 in ith position and a -1 in jth position + """ + n = self.n + root = [0]*n + root[i] = 1 + root[j] = -1 + return root + + def simple_root(self, i): + """The ith simple root for the C series + + Every lie algebra has a unique root system. + Given a root system Q, there is a subset of the + roots such that an element of Q is called a + simple root if it cannot be written as the sum + of two elements in Q. If we let D denote the + set of simple roots, then it is clear that every + element of Q can be written as a linear combination + of elements of D with all coefficients non-negative. + + In C_n, the first n-1 simple roots are the same as + the roots in A_(n-1) (a 1 in the ith position, a -1 + in the (i+1)th position, and zeroes elsewhere). The + nth simple root is the root in which there is a 2 in + the nth position and zeroes elsewhere. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("C3") + >>> c.simple_root(2) + [0, 1, -1] + + """ + + n = self.n + if i < n: + return self.basic_root(i-1,i) + else: + root = [0]*self.n + root[n-1] = 2 + return root + + + def positive_roots(self): + """Generates all the positive roots of A_n + + This is half of all of the roots of C_n; by multiplying all the + positive roots by -1 we get the negative roots. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A3") + >>> c.positive_roots() + {1: [1, -1, 0, 0], 2: [1, 0, -1, 0], 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], + 5: [0, 1, 0, -1], 6: [0, 0, 1, -1]} + + """ + + n = self.n + posroots = {} + k = 0 + for i in range(0, n-1): + for j in range(i+1, n): + k += 1 + posroots[k] = self.basic_root(i, j) + k += 1 + root = self.basic_root(i, j) + root[j] = 1 + posroots[k] = root + + for i in range(0, n): + k += 1 + root = [0]*n + root[i] = 2 + posroots[k] = root + + return posroots + + def roots(self): + """ + Returns the total number of roots for C_n" + """ + + n = self.n + return 2*(n**2) + + def cartan_matrix(self): + """The Cartan matrix for C_n + + The Cartan matrix matrix for a Lie algebra is + generated by assigning an ordering to the simple + roots, (alpha[1], ...., alpha[l]). Then the ijth + entry of the Cartan matrix is (). + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType('C4') + >>> c.cartan_matrix() + Matrix([ + [ 2, -1, 0, 0], + [-1, 2, -1, 0], + [ 0, -1, 2, -1], + [ 0, 0, -2, 2]]) + + """ + + n = self.n + m = 2 * eye(n) + for i in range(1, n - 1): + m[i, i+1] = -1 + m[i, i-1] = -1 + m[0,1] = -1 + m[n-1, n-2] = -2 + return m + + + def basis(self): + """ + Returns the number of independent generators of C_n + """ + + n = self.n + return n*(2*n + 1) + + def lie_algebra(self): + """ + Returns the Lie algebra associated with C_n" + """ + + n = self.n + return "sp(" + str(2*n) + ")" + + def dynkin_diagram(self): + n = self.n + diag = "---".join("0" for i in range(1, n)) + "=<=0\n" + diag += " ".join(str(i) for i in range(1, n+1)) + return diag diff --git a/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_d.py b/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_d.py new file mode 100644 index 0000000000000000000000000000000000000000..9450d76e906c79e23db0ce223ed0de03d71c1199 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_d.py @@ -0,0 +1,173 @@ +from .cartan_type import Standard_Cartan +from sympy.core.backend import eye + +class TypeD(Standard_Cartan): + + def __new__(cls, n): + if n < 3: + raise ValueError("n cannot be less than 3") + return Standard_Cartan.__new__(cls, "D", n) + + + def dimension(self): + """Dmension of the vector space V underlying the Lie algebra + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("D4") + >>> c.dimension() + 4 + """ + + return self.n + + def basic_root(self, i, j): + """ + This is a method just to generate roots + with a 1 iin the ith position and a -1 + in the jth position. + + """ + + n = self.n + root = [0]*n + root[i] = 1 + root[j] = -1 + return root + + def simple_root(self, i): + """ + Every lie algebra has a unique root system. + Given a root system Q, there is a subset of the + roots such that an element of Q is called a + simple root if it cannot be written as the sum + of two elements in Q. If we let D denote the + set of simple roots, then it is clear that every + element of Q can be written as a linear combination + of elements of D with all coefficients non-negative. + + In D_n, the first n-1 simple roots are the same as + the roots in A_(n-1) (a 1 in the ith position, a -1 + in the (i+1)th position, and zeroes elsewhere). + The nth simple root is the root in which there 1s in + the nth and (n-1)th positions, and zeroes elsewhere. + + This method returns the ith simple root for the D series. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("D4") + >>> c.simple_root(2) + [0, 1, -1, 0] + + """ + + n = self.n + if i < n: + return self.basic_root(i-1, i) + else: + root = [0]*n + root[n-2] = 1 + root[n-1] = 1 + return root + + + def positive_roots(self): + """ + This method generates all the positive roots of + A_n. This is half of all of the roots of D_n + by multiplying all the positive roots by -1 we + get the negative roots. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A3") + >>> c.positive_roots() + {1: [1, -1, 0, 0], 2: [1, 0, -1, 0], 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], + 5: [0, 1, 0, -1], 6: [0, 0, 1, -1]} + """ + + n = self.n + posroots = {} + k = 0 + for i in range(0, n-1): + for j in range(i+1, n): + k += 1 + posroots[k] = self.basic_root(i, j) + k += 1 + root = self.basic_root(i, j) + root[j] = 1 + posroots[k] = root + return posroots + + def roots(self): + """ + Returns the total number of roots for D_n" + """ + + n = self.n + return 2*n*(n-1) + + def cartan_matrix(self): + """ + Returns the Cartan matrix for D_n. + The Cartan matrix matrix for a Lie algebra is + generated by assigning an ordering to the simple + roots, (alpha[1], ...., alpha[l]). Then the ijth + entry of the Cartan matrix is (). + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType('D4') + >>> c.cartan_matrix() + Matrix([ + [ 2, -1, 0, 0], + [-1, 2, -1, -1], + [ 0, -1, 2, 0], + [ 0, -1, 0, 2]]) + + """ + + n = self.n + m = 2*eye(n) + for i in range(1, n - 2): + m[i,i+1] = -1 + m[i,i-1] = -1 + m[n-2, n-3] = -1 + m[n-3, n-1] = -1 + m[n-1, n-3] = -1 + m[0, 1] = -1 + return m + + def basis(self): + """ + Returns the number of independent generators of D_n + """ + n = self.n + return n*(n-1)/2 + + def lie_algebra(self): + """ + Returns the Lie algebra associated with D_n" + """ + + n = self.n + return "so(" + str(2*n) + ")" + + def dynkin_diagram(self): + n = self.n + diag = " "*4*(n-3) + str(n-1) + "\n" + diag += " "*4*(n-3) + "0\n" + diag += " "*4*(n-3) +"|\n" + diag += " "*4*(n-3) + "|\n" + diag += "---".join("0" for i in range(1,n)) + "\n" + diag += " ".join(str(i) for i in range(1, n-1)) + " "+str(n) + return diag diff --git a/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_e.py b/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_e.py new file mode 100644 index 0000000000000000000000000000000000000000..3db9a820d31bff31acc58ba1592a1b10f8be53db --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_e.py @@ -0,0 +1,275 @@ +import itertools + +from .cartan_type import Standard_Cartan +from sympy.core.backend import eye, Rational +from sympy.core.singleton import S + +class TypeE(Standard_Cartan): + + def __new__(cls, n): + if n < 6 or n > 8: + raise ValueError("Invalid value of n") + return Standard_Cartan.__new__(cls, "E", n) + + def dimension(self): + """Dimension of the vector space V underlying the Lie algebra + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("E6") + >>> c.dimension() + 8 + """ + + return 8 + + def basic_root(self, i, j): + """ + This is a method just to generate roots + with a -1 in the ith position and a 1 + in the jth position. + + """ + + root = [0]*8 + root[i] = -1 + root[j] = 1 + return root + + def simple_root(self, i): + """ + Every Lie algebra has a unique root system. + Given a root system Q, there is a subset of the + roots such that an element of Q is called a + simple root if it cannot be written as the sum + of two elements in Q. If we let D denote the + set of simple roots, then it is clear that every + element of Q can be written as a linear combination + of elements of D with all coefficients non-negative. + + This method returns the ith simple root for E_n. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("E6") + >>> c.simple_root(2) + [1, 1, 0, 0, 0, 0, 0, 0] + """ + n = self.n + if i == 1: + root = [-0.5]*8 + root[0] = 0.5 + root[7] = 0.5 + return root + elif i == 2: + root = [0]*8 + root[1] = 1 + root[0] = 1 + return root + else: + if i in (7, 8) and n == 6: + raise ValueError("E6 only has six simple roots!") + if i == 8 and n == 7: + raise ValueError("E7 only has seven simple roots!") + + return self.basic_root(i - 3, i - 2) + + def positive_roots(self): + """ + This method generates all the positive roots of + A_n. This is half of all of the roots of E_n; + by multiplying all the positive roots by -1 we + get the negative roots. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A3") + >>> c.positive_roots() + {1: [1, -1, 0, 0], 2: [1, 0, -1, 0], 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], + 5: [0, 1, 0, -1], 6: [0, 0, 1, -1]} + """ + n = self.n + neghalf = Rational(-1, 2) + poshalf = S.Half + if n == 6: + posroots = {} + k = 0 + for i in range(n-1): + for j in range(i+1, n-1): + k += 1 + root = self.basic_root(i, j) + posroots[k] = root + k += 1 + root = self.basic_root(i, j) + root[i] = 1 + posroots[k] = root + + root = [poshalf, poshalf, poshalf, poshalf, poshalf, + neghalf, neghalf, poshalf] + for a, b, c, d, e in itertools.product( + range(2), range(2), range(2), range(2), range(2)): + if (a + b + c + d + e)%2 == 0: + k += 1 + if a == 1: + root[0] = neghalf + if b == 1: + root[1] = neghalf + if c == 1: + root[2] = neghalf + if d == 1: + root[3] = neghalf + if e == 1: + root[4] = neghalf + posroots[k] = root[:] + return posroots + if n == 7: + posroots = {} + k = 0 + for i in range(n-1): + for j in range(i+1, n-1): + k += 1 + root = self.basic_root(i, j) + posroots[k] = root + k += 1 + root = self.basic_root(i, j) + root[i] = 1 + posroots[k] = root + + k += 1 + posroots[k] = [0, 0, 0, 0, 0, 1, 1, 0] + root = [poshalf, poshalf, poshalf, poshalf, poshalf, + neghalf, neghalf, poshalf] + for a, b, c, d, e, f in itertools.product( + range(2), range(2), range(2), range(2), range(2), range(2)): + if (a + b + c + d + e + f)%2 == 0: + k += 1 + if a == 1: + root[0] = neghalf + if b == 1: + root[1] = neghalf + if c == 1: + root[2] = neghalf + if d == 1: + root[3] = neghalf + if e == 1: + root[4] = neghalf + if f == 1: + root[5] = poshalf + posroots[k] = root[:] + return posroots + if n == 8: + posroots = {} + k = 0 + for i in range(n): + for j in range(i+1, n): + k += 1 + root = self.basic_root(i, j) + posroots[k] = root + k += 1 + root = self.basic_root(i, j) + root[i] = 1 + posroots[k] = root + + root = [poshalf, poshalf, poshalf, poshalf, poshalf, + neghalf, neghalf, poshalf] + for a, b, c, d, e, f, g in itertools.product( + range(2), range(2), range(2), range(2), range(2), + range(2), range(2)): + if (a + b + c + d + e + f + g)%2 == 0: + k += 1 + if a == 1: + root[0] = neghalf + if b == 1: + root[1] = neghalf + if c == 1: + root[2] = neghalf + if d == 1: + root[3] = neghalf + if e == 1: + root[4] = neghalf + if f == 1: + root[5] = poshalf + if g == 1: + root[6] = poshalf + posroots[k] = root[:] + return posroots + + + + def roots(self): + """ + Returns the total number of roots of E_n + """ + + n = self.n + if n == 6: + return 72 + if n == 7: + return 126 + if n == 8: + return 240 + + + def cartan_matrix(self): + """ + Returns the Cartan matrix for G_2 + The Cartan matrix matrix for a Lie algebra is + generated by assigning an ordering to the simple + roots, (alpha[1], ...., alpha[l]). Then the ijth + entry of the Cartan matrix is (). + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType('A4') + >>> c.cartan_matrix() + Matrix([ + [ 2, -1, 0, 0], + [-1, 2, -1, 0], + [ 0, -1, 2, -1], + [ 0, 0, -1, 2]]) + + + """ + + n = self.n + m = 2*eye(n) + for i in range(3, n - 1): + m[i, i+1] = -1 + m[i, i-1] = -1 + m[0, 2] = m[2, 0] = -1 + m[1, 3] = m[3, 1] = -1 + m[2, 3] = -1 + m[n-1, n-2] = -1 + return m + + + def basis(self): + """ + Returns the number of independent generators of E_n + """ + + n = self.n + if n == 6: + return 78 + if n == 7: + return 133 + if n == 8: + return 248 + + def dynkin_diagram(self): + n = self.n + diag = " "*8 + str(2) + "\n" + diag += " "*8 + "0\n" + diag += " "*8 + "|\n" + diag += " "*8 + "|\n" + diag += "---".join("0" for i in range(1, n)) + "\n" + diag += "1 " + " ".join(str(i) for i in range(3, n+1)) + return diag diff --git a/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_f.py b/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_f.py new file mode 100644 index 0000000000000000000000000000000000000000..f04da557870f2cd21818cf69c454ef598e2ab65a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_f.py @@ -0,0 +1,162 @@ +from .cartan_type import Standard_Cartan +from sympy.core.backend import Matrix, Rational + + +class TypeF(Standard_Cartan): + + def __new__(cls, n): + if n != 4: + raise ValueError("n should be 4") + return Standard_Cartan.__new__(cls, "F", 4) + + def dimension(self): + """Dimension of the vector space V underlying the Lie algebra + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("F4") + >>> c.dimension() + 4 + """ + + return 4 + + + def basic_root(self, i, j): + """Generate roots with 1 in ith position and -1 in jth position + + """ + + n = self.n + root = [0]*n + root[i] = 1 + root[j] = -1 + return root + + def simple_root(self, i): + """The ith simple root of F_4 + + Every lie algebra has a unique root system. + Given a root system Q, there is a subset of the + roots such that an element of Q is called a + simple root if it cannot be written as the sum + of two elements in Q. If we let D denote the + set of simple roots, then it is clear that every + element of Q can be written as a linear combination + of elements of D with all coefficients non-negative. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("F4") + >>> c.simple_root(3) + [0, 0, 0, 1] + + """ + + if i < 3: + return self.basic_root(i-1, i) + if i == 3: + root = [0]*4 + root[3] = 1 + return root + if i == 4: + root = [Rational(-1, 2)]*4 + return root + + def positive_roots(self): + """Generate all the positive roots of A_n + + This is half of all of the roots of F_4; by multiplying all the + positive roots by -1 we get the negative roots. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A3") + >>> c.positive_roots() + {1: [1, -1, 0, 0], 2: [1, 0, -1, 0], 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], + 5: [0, 1, 0, -1], 6: [0, 0, 1, -1]} + + """ + + n = self.n + posroots = {} + k = 0 + for i in range(0, n-1): + for j in range(i+1, n): + k += 1 + posroots[k] = self.basic_root(i, j) + k += 1 + root = self.basic_root(i, j) + root[j] = 1 + posroots[k] = root + + for i in range(0, n): + k += 1 + root = [0]*n + root[i] = 1 + posroots[k] = root + + k += 1 + root = [Rational(1, 2)]*n + posroots[k] = root + for i in range(1, 4): + k += 1 + root = [Rational(1, 2)]*n + root[i] = Rational(-1, 2) + posroots[k] = root + + posroots[k+1] = [Rational(1, 2), Rational(1, 2), Rational(-1, 2), Rational(-1, 2)] + posroots[k+2] = [Rational(1, 2), Rational(-1, 2), Rational(1, 2), Rational(-1, 2)] + posroots[k+3] = [Rational(1, 2), Rational(-1, 2), Rational(-1, 2), Rational(1, 2)] + posroots[k+4] = [Rational(1, 2), Rational(-1, 2), Rational(-1, 2), Rational(-1, 2)] + + return posroots + + + def roots(self): + """ + Returns the total number of roots for F_4 + """ + return 48 + + def cartan_matrix(self): + """The Cartan matrix for F_4 + + The Cartan matrix matrix for a Lie algebra is + generated by assigning an ordering to the simple + roots, (alpha[1], ...., alpha[l]). Then the ijth + entry of the Cartan matrix is (). + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType('A4') + >>> c.cartan_matrix() + Matrix([ + [ 2, -1, 0, 0], + [-1, 2, -1, 0], + [ 0, -1, 2, -1], + [ 0, 0, -1, 2]]) + """ + + m = Matrix( 4, 4, [2, -1, 0, 0, -1, 2, -2, 0, 0, + -1, 2, -1, 0, 0, -1, 2]) + return m + + def basis(self): + """ + Returns the number of independent generators of F_4 + """ + return 52 + + def dynkin_diagram(self): + diag = "0---0=>=0---0\n" + diag += " ".join(str(i) for i in range(1, 5)) + return diag diff --git a/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_g.py b/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_g.py new file mode 100644 index 0000000000000000000000000000000000000000..014409cf5ed966b53c596b14e0073e89ceee05b6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/liealgebras/type_g.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- + +from .cartan_type import Standard_Cartan +from sympy.core.backend import Matrix + +class TypeG(Standard_Cartan): + + def __new__(cls, n): + if n != 2: + raise ValueError("n should be 2") + return Standard_Cartan.__new__(cls, "G", 2) + + + def dimension(self): + """Dimension of the vector space V underlying the Lie algebra + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("G2") + >>> c.dimension() + 3 + """ + return 3 + + def simple_root(self, i): + """The ith simple root of G_2 + + Every lie algebra has a unique root system. + Given a root system Q, there is a subset of the + roots such that an element of Q is called a + simple root if it cannot be written as the sum + of two elements in Q. If we let D denote the + set of simple roots, then it is clear that every + element of Q can be written as a linear combination + of elements of D with all coefficients non-negative. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("G2") + >>> c.simple_root(1) + [0, 1, -1] + + """ + if i == 1: + return [0, 1, -1] + else: + return [1, -2, 1] + + def positive_roots(self): + """Generate all the positive roots of A_n + + This is half of all of the roots of A_n; by multiplying all the + positive roots by -1 we get the negative roots. + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("A3") + >>> c.positive_roots() + {1: [1, -1, 0, 0], 2: [1, 0, -1, 0], 3: [1, 0, 0, -1], 4: [0, 1, -1, 0], + 5: [0, 1, 0, -1], 6: [0, 0, 1, -1]} + + """ + + roots = {1: [0, 1, -1], 2: [1, -2, 1], 3: [1, -1, 0], 4: [1, 0, 1], + 5: [1, 1, -2], 6: [2, -1, -1]} + return roots + + def roots(self): + """ + Returns the total number of roots of G_2" + """ + return 12 + + def cartan_matrix(self): + """The Cartan matrix for G_2 + + The Cartan matrix matrix for a Lie algebra is + generated by assigning an ordering to the simple + roots, (alpha[1], ...., alpha[l]). Then the ijth + entry of the Cartan matrix is (). + + Examples + ======== + + >>> from sympy.liealgebras.cartan_type import CartanType + >>> c = CartanType("G2") + >>> c.cartan_matrix() + Matrix([ + [ 2, -1], + [-3, 2]]) + + """ + + m = Matrix( 2, 2, [2, -1, -3, 2]) + return m + + def basis(self): + """ + Returns the number of independent generators of G_2 + """ + return 14 + + def dynkin_diagram(self): + diag = "0≡<≡0\n1 2" + return diag diff --git a/.venv/lib/python3.13/site-packages/sympy/liealgebras/weyl_group.py b/.venv/lib/python3.13/site-packages/sympy/liealgebras/weyl_group.py new file mode 100644 index 0000000000000000000000000000000000000000..15ff70b6f1fc4649268a38ee13e1f717a1c9f5fa --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/liealgebras/weyl_group.py @@ -0,0 +1,403 @@ +# -*- coding: utf-8 -*- + +from .cartan_type import CartanType +from mpmath import fac +from sympy.core.backend import Matrix, eye, Rational, igcd +from sympy.core.basic import Atom + +class WeylGroup(Atom): + + """ + For each semisimple Lie group, we have a Weyl group. It is a subgroup of + the isometry group of the root system. Specifically, it's the subgroup + that is generated by reflections through the hyperplanes orthogonal to + the roots. Therefore, Weyl groups are reflection groups, and so a Weyl + group is a finite Coxeter group. + + """ + + def __new__(cls, cartantype): + obj = Atom.__new__(cls) + obj.cartan_type = CartanType(cartantype) + return obj + + def generators(self): + """ + This method creates the generating reflections of the Weyl group for + a given Lie algebra. For a Lie algebra of rank n, there are n + different generating reflections. This function returns them as + a list. + + Examples + ======== + + >>> from sympy.liealgebras.weyl_group import WeylGroup + >>> c = WeylGroup("F4") + >>> c.generators() + ['r1', 'r2', 'r3', 'r4'] + """ + n = self.cartan_type.rank() + generators = [] + for i in range(1, n+1): + reflection = "r"+str(i) + generators.append(reflection) + return generators + + def group_order(self): + """ + This method returns the order of the Weyl group. + For types A, B, C, D, and E the order depends on + the rank of the Lie algebra. For types F and G, + the order is fixed. + + Examples + ======== + + >>> from sympy.liealgebras.weyl_group import WeylGroup + >>> c = WeylGroup("D4") + >>> c.group_order() + 192.0 + """ + n = self.cartan_type.rank() + if self.cartan_type.series == "A": + return fac(n+1) + + if self.cartan_type.series in ("B", "C"): + return fac(n)*(2**n) + + if self.cartan_type.series == "D": + return fac(n)*(2**(n-1)) + + if self.cartan_type.series == "E": + if n == 6: + return 51840 + if n == 7: + return 2903040 + if n == 8: + return 696729600 + if self.cartan_type.series == "F": + return 1152 + + if self.cartan_type.series == "G": + return 12 + + def group_name(self): + """ + This method returns some general information about the Weyl group for + a given Lie algebra. It returns the name of the group and the elements + it acts on, if relevant. + """ + n = self.cartan_type.rank() + if self.cartan_type.series == "A": + return "S"+str(n+1) + ": the symmetric group acting on " + str(n+1) + " elements." + + if self.cartan_type.series in ("B", "C"): + return "The hyperoctahedral group acting on " + str(2*n) + " elements." + + if self.cartan_type.series == "D": + return "The symmetry group of the " + str(n) + "-dimensional demihypercube." + + if self.cartan_type.series == "E": + if n == 6: + return "The symmetry group of the 6-polytope." + + if n == 7: + return "The symmetry group of the 7-polytope." + + if n == 8: + return "The symmetry group of the 8-polytope." + + if self.cartan_type.series == "F": + return "The symmetry group of the 24-cell, or icositetrachoron." + + if self.cartan_type.series == "G": + return "D6, the dihedral group of order 12, and symmetry group of the hexagon." + + def element_order(self, weylelt): + """ + This method returns the order of a given Weyl group element, which should + be specified by the user in the form of products of the generating + reflections, i.e. of the form r1*r2 etc. + + For types A-F, this method current works by taking the matrix form of + the specified element, and then finding what power of the matrix is the + identity. It then returns this power. + + Examples + ======== + + >>> from sympy.liealgebras.weyl_group import WeylGroup + >>> b = WeylGroup("B4") + >>> b.element_order('r1*r4*r2') + 4 + """ + n = self.cartan_type.rank() + if self.cartan_type.series == "A": + a = self.matrix_form(weylelt) + order = 1 + while a != eye(n+1): + a *= self.matrix_form(weylelt) + order += 1 + return order + + if self.cartan_type.series == "D": + a = self.matrix_form(weylelt) + order = 1 + while a != eye(n): + a *= self.matrix_form(weylelt) + order += 1 + return order + + if self.cartan_type.series == "E": + a = self.matrix_form(weylelt) + order = 1 + while a != eye(8): + a *= self.matrix_form(weylelt) + order += 1 + return order + + if self.cartan_type.series == "G": + elts = list(weylelt) + reflections = elts[1::3] + m = self.delete_doubles(reflections) + while self.delete_doubles(m) != m: + m = self.delete_doubles(m) + reflections = m + if len(reflections) % 2 == 1: + return 2 + + elif len(reflections) == 0: + return 1 + + else: + if len(reflections) == 1: + return 2 + else: + m = len(reflections) // 2 + lcm = (6 * m)/ igcd(m, 6) + order = lcm / m + return order + + + if self.cartan_type.series == 'F': + a = self.matrix_form(weylelt) + order = 1 + while a != eye(4): + a *= self.matrix_form(weylelt) + order += 1 + return order + + + if self.cartan_type.series in ("B", "C"): + a = self.matrix_form(weylelt) + order = 1 + while a != eye(n): + a *= self.matrix_form(weylelt) + order += 1 + return order + + def delete_doubles(self, reflections): + """ + This is a helper method for determining the order of an element in the + Weyl group of G2. It takes a Weyl element and if repeated simple reflections + in it, it deletes them. + """ + counter = 0 + copy = list(reflections) + for elt in copy: + if counter < len(copy)-1: + if copy[counter + 1] == elt: + del copy[counter] + del copy[counter] + counter += 1 + + + return copy + + + def matrix_form(self, weylelt): + """ + This method takes input from the user in the form of products of the + generating reflections, and returns the matrix corresponding to the + element of the Weyl group. Since each element of the Weyl group is + a reflection of some type, there is a corresponding matrix representation. + This method uses the standard representation for all the generating + reflections. + + Examples + ======== + + >>> from sympy.liealgebras.weyl_group import WeylGroup + >>> f = WeylGroup("F4") + >>> f.matrix_form('r2*r3') + Matrix([ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 0, -1], + [0, 0, 1, 0]]) + + """ + elts = list(weylelt) + reflections = elts[1::3] + n = self.cartan_type.rank() + if self.cartan_type.series == 'A': + matrixform = eye(n+1) + for elt in reflections: + a = int(elt) + mat = eye(n+1) + mat[a-1, a-1] = 0 + mat[a-1, a] = 1 + mat[a, a-1] = 1 + mat[a, a] = 0 + matrixform *= mat + return matrixform + + if self.cartan_type.series == 'D': + matrixform = eye(n) + for elt in reflections: + a = int(elt) + mat = eye(n) + if a < n: + mat[a-1, a-1] = 0 + mat[a-1, a] = 1 + mat[a, a-1] = 1 + mat[a, a] = 0 + matrixform *= mat + else: + mat[n-2, n-1] = -1 + mat[n-2, n-2] = 0 + mat[n-1, n-2] = -1 + mat[n-1, n-1] = 0 + matrixform *= mat + return matrixform + + if self.cartan_type.series == 'G': + matrixform = eye(3) + for elt in reflections: + a = int(elt) + if a == 1: + gen1 = Matrix([[1, 0, 0], [0, 0, 1], [0, 1, 0]]) + matrixform *= gen1 + else: + gen2 = Matrix([[Rational(2, 3), Rational(2, 3), Rational(-1, 3)], + [Rational(2, 3), Rational(-1, 3), Rational(2, 3)], + [Rational(-1, 3), Rational(2, 3), Rational(2, 3)]]) + matrixform *= gen2 + return matrixform + + if self.cartan_type.series == 'F': + matrixform = eye(4) + for elt in reflections: + a = int(elt) + if a == 1: + mat = Matrix([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) + matrixform *= mat + elif a == 2: + mat = Matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]) + matrixform *= mat + elif a == 3: + mat = Matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, -1]]) + matrixform *= mat + else: + + mat = Matrix([[Rational(1, 2), Rational(1, 2), Rational(1, 2), Rational(1, 2)], + [Rational(1, 2), Rational(1, 2), Rational(-1, 2), Rational(-1, 2)], + [Rational(1, 2), Rational(-1, 2), Rational(1, 2), Rational(-1, 2)], + [Rational(1, 2), Rational(-1, 2), Rational(-1, 2), Rational(1, 2)]]) + matrixform *= mat + return matrixform + + if self.cartan_type.series == 'E': + matrixform = eye(8) + for elt in reflections: + a = int(elt) + if a == 1: + mat = Matrix([[Rational(3, 4), Rational(1, 4), Rational(1, 4), Rational(1, 4), + Rational(1, 4), Rational(1, 4), Rational(1, 4), Rational(-1, 4)], + [Rational(1, 4), Rational(3, 4), Rational(-1, 4), Rational(-1, 4), + Rational(-1, 4), Rational(-1, 4), Rational(1, 4), Rational(-1, 4)], + [Rational(1, 4), Rational(-1, 4), Rational(3, 4), Rational(-1, 4), + Rational(-1, 4), Rational(-1, 4), Rational(-1, 4), Rational(1, 4)], + [Rational(1, 4), Rational(-1, 4), Rational(-1, 4), Rational(3, 4), + Rational(-1, 4), Rational(-1, 4), Rational(-1, 4), Rational(1, 4)], + [Rational(1, 4), Rational(-1, 4), Rational(-1, 4), Rational(-1, 4), + Rational(3, 4), Rational(-1, 4), Rational(-1, 4), Rational(1, 4)], + [Rational(1, 4), Rational(-1, 4), Rational(-1, 4), Rational(-1, 4), + Rational(-1, 4), Rational(3, 4), Rational(-1, 4), Rational(1, 4)], + [Rational(1, 4), Rational(-1, 4), Rational(-1, 4), Rational(-1, 4), + Rational(-1, 4), Rational(-1, 4), Rational(-3, 4), Rational(1, 4)], + [Rational(1, 4), Rational(-1, 4), Rational(-1, 4), Rational(-1, 4), + Rational(-1, 4), Rational(-1, 4), Rational(-1, 4), Rational(3, 4)]]) + matrixform *= mat + elif a == 2: + mat = eye(8) + mat[0, 0] = 0 + mat[0, 1] = -1 + mat[1, 0] = -1 + mat[1, 1] = 0 + matrixform *= mat + else: + mat = eye(8) + mat[a-3, a-3] = 0 + mat[a-3, a-2] = 1 + mat[a-2, a-3] = 1 + mat[a-2, a-2] = 0 + matrixform *= mat + return matrixform + + + if self.cartan_type.series in ("B", "C"): + matrixform = eye(n) + for elt in reflections: + a = int(elt) + mat = eye(n) + if a == 1: + mat[0, 0] = -1 + matrixform *= mat + else: + mat[a - 2, a - 2] = 0 + mat[a-2, a-1] = 1 + mat[a - 1, a - 2] = 1 + mat[a -1, a - 1] = 0 + matrixform *= mat + return matrixform + + + + def coxeter_diagram(self): + """ + This method returns the Coxeter diagram corresponding to a Weyl group. + The Coxeter diagram can be obtained from a Lie algebra's Dynkin diagram + by deleting all arrows; the Coxeter diagram is the undirected graph. + The vertices of the Coxeter diagram represent the generating reflections + of the Weyl group, $s_i$. An edge is drawn between $s_i$ and $s_j$ if the order + $m(i, j)$ of $s_is_j$ is greater than two. If there is one edge, the order + $m(i, j)$ is 3. If there are two edges, the order $m(i, j)$ is 4, and if there + are three edges, the order $m(i, j)$ is 6. + + Examples + ======== + + >>> from sympy.liealgebras.weyl_group import WeylGroup + >>> c = WeylGroup("B3") + >>> print(c.coxeter_diagram()) + 0---0===0 + 1 2 3 + """ + n = self.cartan_type.rank() + if self.cartan_type.series in ("A", "D", "E"): + return self.cartan_type.dynkin_diagram() + + if self.cartan_type.series in ("B", "C"): + diag = "---".join("0" for i in range(1, n)) + "===0\n" + diag += " ".join(str(i) for i in range(1, n+1)) + return diag + + if self.cartan_type.series == "F": + diag = "0---0===0---0\n" + diag += " ".join(str(i) for i in range(1, 5)) + return diag + + if self.cartan_type.series == "G": + diag = "0≡≡≡0\n1 2" + return diag diff --git a/.venv/lib/python3.13/site-packages/sympy/logic/__init__.py b/.venv/lib/python3.13/site-packages/sympy/logic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb26903a384e9df3a0f02a92c488c5442cee1486 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/logic/__init__.py @@ -0,0 +1,12 @@ +from .boolalg import (to_cnf, to_dnf, to_nnf, And, Or, Not, Xor, Nand, Nor, Implies, + Equivalent, ITE, POSform, SOPform, simplify_logic, bool_map, true, false, + gateinputcount) +from .inference import satisfiable + +__all__ = [ + 'to_cnf', 'to_dnf', 'to_nnf', 'And', 'Or', 'Not', 'Xor', 'Nand', 'Nor', + 'Implies', 'Equivalent', 'ITE', 'POSform', 'SOPform', 'simplify_logic', + 'bool_map', 'true', 'false', 'gateinputcount', + + 'satisfiable', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/logic/boolalg.py b/.venv/lib/python3.13/site-packages/sympy/logic/boolalg.py new file mode 100644 index 0000000000000000000000000000000000000000..8e11a9b6361ac5d7e355d5d4fb176d8df443e07e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/logic/boolalg.py @@ -0,0 +1,3587 @@ +""" +Boolean algebra module for SymPy +""" + +from __future__ import annotations +from typing import TYPE_CHECKING, overload, Any +from collections.abc import Iterable, Mapping + +from collections import defaultdict +from itertools import chain, combinations, product, permutations +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.cache import cacheit +from sympy.core.containers import Tuple +from sympy.core.decorators import sympify_method_args, sympify_return +from sympy.core.function import Application, Derivative +from sympy.core.kind import BooleanKind, NumberKind +from sympy.core.numbers import Number +from sympy.core.operations import LatticeOp +from sympy.core.singleton import Singleton, S +from sympy.core.sorting import ordered +from sympy.core.sympify import _sympy_converter, _sympify, sympify +from sympy.utilities.iterables import sift, ibin +from sympy.utilities.misc import filldedent + + +def as_Boolean(e): + """Like ``bool``, return the Boolean value of an expression, e, + which can be any instance of :py:class:`~.Boolean` or ``bool``. + + Examples + ======== + + >>> from sympy import true, false, nan + >>> from sympy.logic.boolalg import as_Boolean + >>> from sympy.abc import x + >>> as_Boolean(0) is false + True + >>> as_Boolean(1) is true + True + >>> as_Boolean(x) + x + >>> as_Boolean(2) + Traceback (most recent call last): + ... + TypeError: expecting bool or Boolean, not `2`. + >>> as_Boolean(nan) + Traceback (most recent call last): + ... + TypeError: expecting bool or Boolean, not `nan`. + + """ + from sympy.core.symbol import Symbol + if e == True: + return true + if e == False: + return false + if isinstance(e, Symbol): + z = e.is_zero + if z is None: + return e + return false if z else true + if isinstance(e, Boolean): + return e + raise TypeError('expecting bool or Boolean, not `%s`.' % e) + + +@sympify_method_args +class Boolean(Basic): + """A Boolean object is an object for which logic operations make sense.""" + + __slots__ = () + + kind = BooleanKind + + if TYPE_CHECKING: + + def __new__(cls, *args: Basic | complex) -> Boolean: + ... + + @overload # type: ignore + def subs(self, arg1: Mapping[Basic | complex, Boolean | complex], arg2: None=None) -> Boolean: ... + @overload + def subs(self, arg1: Iterable[tuple[Basic | complex, Boolean | complex]], arg2: None=None, **kwargs: Any) -> Boolean: ... + @overload + def subs(self, arg1: Boolean | complex, arg2: Boolean | complex) -> Boolean: ... + @overload + def subs(self, arg1: Mapping[Basic | complex, Basic | complex], arg2: None=None, **kwargs: Any) -> Basic: ... + @overload + def subs(self, arg1: Iterable[tuple[Basic | complex, Basic | complex]], arg2: None=None, **kwargs: Any) -> Basic: ... + @overload + def subs(self, arg1: Basic | complex, arg2: Basic | complex, **kwargs: Any) -> Basic: ... + + def subs(self, arg1: Mapping[Basic | complex, Basic | complex] | Basic | complex, # type: ignore + arg2: Basic | complex | None = None, **kwargs: Any) -> Basic: + ... + + def simplify(self, **kwargs) -> Boolean: + ... + + @sympify_return([('other', 'Boolean')], NotImplemented) + def __and__(self, other): + return And(self, other) + + __rand__ = __and__ + + @sympify_return([('other', 'Boolean')], NotImplemented) + def __or__(self, other): + return Or(self, other) + + __ror__ = __or__ + + def __invert__(self): + """Overloading for ~""" + return Not(self) + + @sympify_return([('other', 'Boolean')], NotImplemented) + def __rshift__(self, other): + return Implies(self, other) + + @sympify_return([('other', 'Boolean')], NotImplemented) + def __lshift__(self, other): + return Implies(other, self) + + __rrshift__ = __lshift__ + __rlshift__ = __rshift__ + + @sympify_return([('other', 'Boolean')], NotImplemented) + def __xor__(self, other): + return Xor(self, other) + + __rxor__ = __xor__ + + def equals(self, other): + """ + Returns ``True`` if the given formulas have the same truth table. + For two formulas to be equal they must have the same literals. + + Examples + ======== + + >>> from sympy.abc import A, B, C + >>> from sympy import And, Or, Not + >>> (A >> B).equals(~B >> ~A) + True + >>> Not(And(A, B, C)).equals(And(Not(A), Not(B), Not(C))) + False + >>> Not(And(A, Not(A))).equals(Or(B, Not(B))) + False + + """ + from sympy.logic.inference import satisfiable + from sympy.core.relational import Relational + + if self.has(Relational) or other.has(Relational): + raise NotImplementedError('handling of relationals') + return self.atoms() == other.atoms() and \ + not satisfiable(Not(Equivalent(self, other))) + + def to_nnf(self, simplify=True): + # override where necessary + return self + + def as_set(self): + """ + Rewrites Boolean expression in terms of real sets. + + Examples + ======== + + >>> from sympy import Symbol, Eq, Or, And + >>> x = Symbol('x', real=True) + >>> Eq(x, 0).as_set() + {0} + >>> (x > 0).as_set() + Interval.open(0, oo) + >>> And(-2 < x, x < 2).as_set() + Interval.open(-2, 2) + >>> Or(x < -2, 2 < x).as_set() + Union(Interval.open(-oo, -2), Interval.open(2, oo)) + + """ + from sympy.calculus.util import periodicity + from sympy.core.relational import Relational + + free = self.free_symbols + if len(free) == 1: + x = free.pop() + if x.kind is NumberKind: + reps = {} + for r in self.atoms(Relational): + if periodicity(r, x) not in (0, None): + s = r._eval_as_set() + if s in (S.EmptySet, S.UniversalSet, S.Reals): + reps[r] = s.as_relational(x) + continue + raise NotImplementedError(filldedent(''' + as_set is not implemented for relationals + with periodic solutions + ''')) + new = self.subs(reps) + if new.func != self.func: + return new.as_set() # restart with new obj + else: + return new._eval_as_set() + + return self._eval_as_set() + else: + raise NotImplementedError("Sorry, as_set has not yet been" + " implemented for multivariate" + " expressions") + + @property + def binary_symbols(self): + from sympy.core.relational import Eq, Ne + return set().union(*[i.binary_symbols for i in self.args + if i.is_Boolean or i.is_Symbol + or isinstance(i, (Eq, Ne))]) + + def _eval_refine(self, assumptions): + from sympy.assumptions import ask + ret = ask(self, assumptions) + if ret is True: + return true + elif ret is False: + return false + return None + + +class BooleanAtom(Boolean): + """ + Base class of :py:class:`~.BooleanTrue` and :py:class:`~.BooleanFalse`. + """ + is_Boolean = True + is_Atom = True + _op_priority = 11 # higher than Expr + + def simplify(self, *a, **kw): + return self + + def expand(self, *a, **kw): + return self + + @property + def canonical(self): + return self + + def _noop(self, other=None): + raise TypeError('BooleanAtom not allowed in this context.') + + __add__ = _noop + __radd__ = _noop + __sub__ = _noop + __rsub__ = _noop + __mul__ = _noop + __rmul__ = _noop + __pow__ = _noop + __rpow__ = _noop + __truediv__ = _noop + __rtruediv__ = _noop + __mod__ = _noop + __rmod__ = _noop + _eval_power = _noop + + def __lt__(self, other): + raise TypeError(filldedent(''' + A Boolean argument can only be used in + Eq and Ne; all other relationals expect + real expressions. + ''')) + + __le__ = __lt__ + __gt__ = __lt__ + __ge__ = __lt__ + # \\\ + + def _eval_simplify(self, **kwargs): + return self + + +class BooleanTrue(BooleanAtom, metaclass=Singleton): + """ + SymPy version of ``True``, a singleton that can be accessed via ``S.true``. + + This is the SymPy version of ``True``, for use in the logic module. The + primary advantage of using ``true`` instead of ``True`` is that shorthand Boolean + operations like ``~`` and ``>>`` will work as expected on this class, whereas with + True they act bitwise on 1. Functions in the logic module will return this + class when they evaluate to true. + + Notes + ===== + + There is liable to be some confusion as to when ``True`` should + be used and when ``S.true`` should be used in various contexts + throughout SymPy. An important thing to remember is that + ``sympify(True)`` returns ``S.true``. This means that for the most + part, you can just use ``True`` and it will automatically be converted + to ``S.true`` when necessary, similar to how you can generally use 1 + instead of ``S.One``. + + The rule of thumb is: + + "If the boolean in question can be replaced by an arbitrary symbolic + ``Boolean``, like ``Or(x, y)`` or ``x > 1``, use ``S.true``. + Otherwise, use ``True``" + + In other words, use ``S.true`` only on those contexts where the + boolean is being used as a symbolic representation of truth. + For example, if the object ends up in the ``.args`` of any expression, + then it must necessarily be ``S.true`` instead of ``True``, as + elements of ``.args`` must be ``Basic``. On the other hand, + ``==`` is not a symbolic operation in SymPy, since it always returns + ``True`` or ``False``, and does so in terms of structural equality + rather than mathematical, so it should return ``True``. The assumptions + system should use ``True`` and ``False``. Aside from not satisfying + the above rule of thumb, the assumptions system uses a three-valued logic + (``True``, ``False``, ``None``), whereas ``S.true`` and ``S.false`` + represent a two-valued logic. When in doubt, use ``True``. + + "``S.true == True is True``." + + While "``S.true is True``" is ``False``, "``S.true == True``" + is ``True``, so if there is any doubt over whether a function or + expression will return ``S.true`` or ``True``, just use ``==`` + instead of ``is`` to do the comparison, and it will work in either + case. Finally, for boolean flags, it's better to just use ``if x`` + instead of ``if x is True``. To quote PEP 8: + + Do not compare boolean values to ``True`` or ``False`` + using ``==``. + + * Yes: ``if greeting:`` + * No: ``if greeting == True:`` + * Worse: ``if greeting is True:`` + + Examples + ======== + + >>> from sympy import sympify, true, false, Or + >>> sympify(True) + True + >>> _ is True, _ is true + (False, True) + + >>> Or(true, false) + True + >>> _ is true + True + + Python operators give a boolean result for true but a + bitwise result for True + + >>> ~true, ~True # doctest: +SKIP + (False, -2) + >>> true >> true, True >> True + (True, 0) + + See Also + ======== + + sympy.logic.boolalg.BooleanFalse + + """ + def __bool__(self): + return True + + def __hash__(self): + return hash(True) + + def __eq__(self, other): + if other is True: + return True + if other is False: + return False + return super().__eq__(other) + + @property + def negated(self): + return false + + def as_set(self): + """ + Rewrite logic operators and relationals in terms of real sets. + + Examples + ======== + + >>> from sympy import true + >>> true.as_set() + UniversalSet + + """ + return S.UniversalSet + + +class BooleanFalse(BooleanAtom, metaclass=Singleton): + """ + SymPy version of ``False``, a singleton that can be accessed via ``S.false``. + + This is the SymPy version of ``False``, for use in the logic module. The + primary advantage of using ``false`` instead of ``False`` is that shorthand + Boolean operations like ``~`` and ``>>`` will work as expected on this class, + whereas with ``False`` they act bitwise on 0. Functions in the logic module + will return this class when they evaluate to false. + + Notes + ====== + + See the notes section in :py:class:`sympy.logic.boolalg.BooleanTrue` + + Examples + ======== + + >>> from sympy import sympify, true, false, Or + >>> sympify(False) + False + >>> _ is False, _ is false + (False, True) + + >>> Or(true, false) + True + >>> _ is true + True + + Python operators give a boolean result for false but a + bitwise result for False + + >>> ~false, ~False # doctest: +SKIP + (True, -1) + >>> false >> false, False >> False + (True, 0) + + See Also + ======== + + sympy.logic.boolalg.BooleanTrue + + """ + def __bool__(self): + return False + + def __hash__(self): + return hash(False) + + def __eq__(self, other): + if other is True: + return False + if other is False: + return True + return super().__eq__(other) + + @property + def negated(self): + return true + + def as_set(self): + """ + Rewrite logic operators and relationals in terms of real sets. + + Examples + ======== + + >>> from sympy import false + >>> false.as_set() + EmptySet + """ + return S.EmptySet + + +true = BooleanTrue() +false = BooleanFalse() +# We want S.true and S.false to work, rather than S.BooleanTrue and +# S.BooleanFalse, but making the class and instance names the same causes some +# major issues (like the inability to import the class directly from this +# file). +S.true = true +S.false = false + +_sympy_converter[bool] = lambda x: true if x else false + + +class BooleanFunction(Application, Boolean): + """Boolean function is a function that lives in a boolean space + It is used as base class for :py:class:`~.And`, :py:class:`~.Or`, + :py:class:`~.Not`, etc. + """ + is_Boolean = True + + def _eval_simplify(self, **kwargs): + rv = simplify_univariate(self) + if not isinstance(rv, BooleanFunction): + return rv.simplify(**kwargs) + rv = rv.func(*[a.simplify(**kwargs) for a in rv.args]) + return simplify_logic(rv) + + def simplify(self, **kwargs): + from sympy.simplify.simplify import simplify + return simplify(self, **kwargs) + + def __lt__(self, other): + raise TypeError(filldedent(''' + A Boolean argument can only be used in + Eq and Ne; all other relationals expect + real expressions. + ''')) + __le__ = __lt__ + __ge__ = __lt__ + __gt__ = __lt__ + + @classmethod + def binary_check_and_simplify(self, *args): + return [as_Boolean(i) for i in args] + + def to_nnf(self, simplify=True): + return self._to_nnf(*self.args, simplify=simplify) + + def to_anf(self, deep=True): + return self._to_anf(*self.args, deep=deep) + + @classmethod + def _to_nnf(cls, *args, **kwargs): + simplify = kwargs.get('simplify', True) + argset = set() + for arg in args: + if not is_literal(arg): + arg = arg.to_nnf(simplify) + if simplify: + if isinstance(arg, cls): + arg = arg.args + else: + arg = (arg,) + for a in arg: + if Not(a) in argset: + return cls.zero + argset.add(a) + else: + argset.add(arg) + return cls(*argset) + + @classmethod + def _to_anf(cls, *args, **kwargs): + deep = kwargs.get('deep', True) + new_args = [] + for arg in args: + if deep: + if not is_literal(arg) or isinstance(arg, Not): + arg = arg.to_anf(deep=deep) + new_args.append(arg) + return cls(*new_args, remove_true=False) + + # the diff method below is copied from Expr class + def diff(self, *symbols, **assumptions): + assumptions.setdefault("evaluate", True) + return Derivative(self, *symbols, **assumptions) + + def _eval_derivative(self, x): + if x in self.binary_symbols: + from sympy.core.relational import Eq + from sympy.functions.elementary.piecewise import Piecewise + return Piecewise( + (0, Eq(self.subs(x, 0), self.subs(x, 1))), + (1, True)) + elif x in self.free_symbols: + # not implemented, see https://www.encyclopediaofmath.org/ + # index.php/Boolean_differential_calculus + pass + else: + return S.Zero + + +class And(LatticeOp, BooleanFunction): + """ + Logical AND function. + + It evaluates its arguments in order, returning false immediately + when an argument is false and true if they are all true. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import And + >>> x & y + x & y + + Notes + ===== + + The ``&`` operator is provided as a convenience, but note that its use + here is different from its normal use in Python, which is bitwise + and. Hence, ``And(a, b)`` and ``a & b`` will produce different results if + ``a`` and ``b`` are integers. + + >>> And(x, y).subs(x, 1) + y + + """ + zero = false + identity = true + + nargs = None + + if TYPE_CHECKING: + + def __new__(cls, *args: Boolean | bool) -> Boolean: # type: ignore + ... + + @property + def args(self) -> tuple[Boolean, ...]: + ... + + @classmethod + def _new_args_filter(cls, args): + args = BooleanFunction.binary_check_and_simplify(*args) + args = LatticeOp._new_args_filter(args, And) + newargs = [] + rel = set() + for x in ordered(args): + if x.is_Relational: + c = x.canonical + if c in rel: + continue + elif c.negated.canonical in rel: + return [false] + else: + rel.add(c) + newargs.append(x) + return newargs + + def _eval_subs(self, old, new): + args = [] + bad = None + for i in self.args: + try: + i = i.subs(old, new) + except TypeError: + # store TypeError + if bad is None: + bad = i + continue + if i == False: + return false + elif i != True: + args.append(i) + if bad is not None: + # let it raise + bad.subs(old, new) + # If old is And, replace the parts of the arguments with new if all + # are there + if isinstance(old, And): + old_set = set(old.args) + if old_set.issubset(args): + args = set(args) - old_set + args.add(new) + + return self.func(*args) + + def _eval_simplify(self, **kwargs): + from sympy.core.relational import Equality, Relational + from sympy.solvers.solveset import linear_coeffs + # standard simplify + rv = super()._eval_simplify(**kwargs) + if not isinstance(rv, And): + return rv + + # simplify args that are equalities involving + # symbols so x == 0 & x == y -> x==0 & y == 0 + Rel, nonRel = sift(rv.args, lambda i: isinstance(i, Relational), + binary=True) + if not Rel: + return rv + eqs, other = sift(Rel, lambda i: isinstance(i, Equality), binary=True) + + measure = kwargs['measure'] + if eqs: + ratio = kwargs['ratio'] + reps = {} + sifted = {} + # group by length of free symbols + sifted = sift(ordered([ + (i.free_symbols, i) for i in eqs]), + lambda x: len(x[0])) + eqs = [] + nonlineqs = [] + while 1 in sifted: + for free, e in sifted.pop(1): + x = free.pop() + if (e.lhs != x or x in e.rhs.free_symbols) and x not in reps: + try: + m, b = linear_coeffs( + Add(e.lhs, -e.rhs, evaluate=False), x) + enew = e.func(x, -b/m) + if measure(enew) <= ratio*measure(e): + e = enew + else: + eqs.append(e) + continue + except ValueError: + pass + if x in reps: + eqs.append(e.subs(x, reps[x])) + elif e.lhs == x and x not in e.rhs.free_symbols: + reps[x] = e.rhs + eqs.append(e) + else: + # x is not yet identified, but may be later + nonlineqs.append(e) + resifted = defaultdict(list) + for k in sifted: + for f, e in sifted[k]: + e = e.xreplace(reps) + f = e.free_symbols + resifted[len(f)].append((f, e)) + sifted = resifted + for k in sifted: + eqs.extend([e for f, e in sifted[k]]) + nonlineqs = [ei.subs(reps) for ei in nonlineqs] + other = [ei.subs(reps) for ei in other] + rv = rv.func(*([i.canonical for i in (eqs + nonlineqs + other)] + nonRel)) + patterns = _simplify_patterns_and() + threeterm_patterns = _simplify_patterns_and3() + return _apply_patternbased_simplification(rv, patterns, + measure, false, + threeterm_patterns=threeterm_patterns) + + def _eval_as_set(self): + from sympy.sets.sets import Intersection + return Intersection(*[arg.as_set() for arg in self.args]) + + def _eval_rewrite_as_Nor(self, *args, **kwargs): + return Nor(*[Not(arg) for arg in self.args]) + + def to_anf(self, deep=True): + if deep: + result = And._to_anf(*self.args, deep=deep) + return distribute_xor_over_and(result) + return self + + +class Or(LatticeOp, BooleanFunction): + """ + Logical OR function + + It evaluates its arguments in order, returning true immediately + when an argument is true, and false if they are all false. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import Or + >>> x | y + x | y + + Notes + ===== + + The ``|`` operator is provided as a convenience, but note that its use + here is different from its normal use in Python, which is bitwise + or. Hence, ``Or(a, b)`` and ``a | b`` will return different things if + ``a`` and ``b`` are integers. + + >>> Or(x, y).subs(x, 0) + y + + """ + zero = true + identity = false + + if TYPE_CHECKING: + + def __new__(cls, *args: Boolean | bool) -> Boolean: # type: ignore + ... + + @property + def args(self) -> tuple[Boolean, ...]: + ... + + @classmethod + def _new_args_filter(cls, args): + newargs = [] + rel = [] + args = BooleanFunction.binary_check_and_simplify(*args) + for x in args: + if x.is_Relational: + c = x.canonical + if c in rel: + continue + nc = c.negated.canonical + if any(r == nc for r in rel): + return [true] + rel.append(c) + newargs.append(x) + return LatticeOp._new_args_filter(newargs, Or) + + def _eval_subs(self, old, new): + args = [] + bad = None + for i in self.args: + try: + i = i.subs(old, new) + except TypeError: + # store TypeError + if bad is None: + bad = i + continue + if i == True: + return true + elif i != False: + args.append(i) + if bad is not None: + # let it raise + bad.subs(old, new) + # If old is Or, replace the parts of the arguments with new if all + # are there + if isinstance(old, Or): + old_set = set(old.args) + if old_set.issubset(args): + args = set(args) - old_set + args.add(new) + + return self.func(*args) + + def _eval_as_set(self): + from sympy.sets.sets import Union + return Union(*[arg.as_set() for arg in self.args]) + + def _eval_rewrite_as_Nand(self, *args, **kwargs): + return Nand(*[Not(arg) for arg in self.args]) + + def _eval_simplify(self, **kwargs): + from sympy.core.relational import Le, Ge, Eq + lege = self.atoms(Le, Ge) + if lege: + reps = {i: self.func( + Eq(i.lhs, i.rhs), i.strict) for i in lege} + return self.xreplace(reps)._eval_simplify(**kwargs) + # standard simplify + rv = super()._eval_simplify(**kwargs) + if not isinstance(rv, Or): + return rv + patterns = _simplify_patterns_or() + return _apply_patternbased_simplification(rv, patterns, + kwargs['measure'], true) + + def to_anf(self, deep=True): + args = range(1, len(self.args) + 1) + args = (combinations(self.args, j) for j in args) + args = chain.from_iterable(args) # powerset + args = (And(*arg) for arg in args) + args = (to_anf(x, deep=deep) if deep else x for x in args) + return Xor(*list(args), remove_true=False) + + +class Not(BooleanFunction): + """ + Logical Not function (negation) + + + Returns ``true`` if the statement is ``false`` or ``False``. + Returns ``false`` if the statement is ``true`` or ``True``. + + Examples + ======== + + >>> from sympy import Not, And, Or + >>> from sympy.abc import x, A, B + >>> Not(True) + False + >>> Not(False) + True + >>> Not(And(True, False)) + True + >>> Not(Or(True, False)) + False + >>> Not(And(And(True, x), Or(x, False))) + ~x + >>> ~x + ~x + >>> Not(And(Or(A, B), Or(~A, ~B))) + ~((A | B) & (~A | ~B)) + + Notes + ===== + + - The ``~`` operator is provided as a convenience, but note that its use + here is different from its normal use in Python, which is bitwise + not. In particular, ``~a`` and ``Not(a)`` will be different if ``a`` is + an integer. Furthermore, since bools in Python subclass from ``int``, + ``~True`` is the same as ``~1`` which is ``-2``, which has a boolean + value of True. To avoid this issue, use the SymPy boolean types + ``true`` and ``false``. + + - As of Python 3.12, the bitwise not operator ``~`` used on a + Python ``bool`` is deprecated and will emit a warning. + + >>> from sympy import true + >>> ~True # doctest: +SKIP + -2 + >>> ~true + False + + """ + + is_Not = True + + @classmethod + def eval(cls, arg): + if isinstance(arg, Number) or arg in (True, False): + return false if arg else true + if arg.is_Not: + return arg.args[0] + # Simplify Relational objects. + if arg.is_Relational: + return arg.negated + + def _eval_as_set(self): + """ + Rewrite logic operators and relationals in terms of real sets. + + Examples + ======== + + >>> from sympy import Not, Symbol + >>> x = Symbol('x') + >>> Not(x > 0).as_set() + Interval(-oo, 0) + """ + return self.args[0].as_set().complement(S.Reals) + + def to_nnf(self, simplify=True): + if is_literal(self): + return self + + expr = self.args[0] + + func, args = expr.func, expr.args + + if func == And: + return Or._to_nnf(*[Not(arg) for arg in args], simplify=simplify) + + if func == Or: + return And._to_nnf(*[Not(arg) for arg in args], simplify=simplify) + + if func == Implies: + a, b = args + return And._to_nnf(a, Not(b), simplify=simplify) + + if func == Equivalent: + return And._to_nnf(Or(*args), Or(*[Not(arg) for arg in args]), + simplify=simplify) + + if func == Xor: + result = [] + for i in range(1, len(args)+1, 2): + for neg in combinations(args, i): + clause = [Not(s) if s in neg else s for s in args] + result.append(Or(*clause)) + return And._to_nnf(*result, simplify=simplify) + + if func == ITE: + a, b, c = args + return And._to_nnf(Or(a, Not(c)), Or(Not(a), Not(b)), simplify=simplify) + + raise ValueError("Illegal operator %s in expression" % func) + + def to_anf(self, deep=True): + return Xor._to_anf(true, self.args[0], deep=deep) + + +class Xor(BooleanFunction): + """ + Logical XOR (exclusive OR) function. + + + Returns True if an odd number of the arguments are True and the rest are + False. + + Returns False if an even number of the arguments are True and the rest are + False. + + Examples + ======== + + >>> from sympy.logic.boolalg import Xor + >>> from sympy import symbols + >>> x, y = symbols('x y') + >>> Xor(True, False) + True + >>> Xor(True, True) + False + >>> Xor(True, False, True, True, False) + True + >>> Xor(True, False, True, False) + False + >>> x ^ y + x ^ y + + Notes + ===== + + The ``^`` operator is provided as a convenience, but note that its use + here is different from its normal use in Python, which is bitwise xor. In + particular, ``a ^ b`` and ``Xor(a, b)`` will be different if ``a`` and + ``b`` are integers. + + >>> Xor(x, y).subs(y, 0) + x + + """ + def __new__(cls, *args, remove_true=True, **kwargs): + argset = set() + obj = super().__new__(cls, *args, **kwargs) + for arg in obj._args: + if isinstance(arg, Number) or arg in (True, False): + if arg: + arg = true + else: + continue + if isinstance(arg, Xor): + for a in arg.args: + argset.remove(a) if a in argset else argset.add(a) + elif arg in argset: + argset.remove(arg) + else: + argset.add(arg) + rel = [(r, r.canonical, r.negated.canonical) + for r in argset if r.is_Relational] + odd = False # is number of complimentary pairs odd? start 0 -> False + remove = [] + for i, (r, c, nc) in enumerate(rel): + for j in range(i + 1, len(rel)): + rj, cj = rel[j][:2] + if cj == nc: + odd = not odd + break + elif cj == c: + break + else: + continue + remove.append((r, rj)) + if odd: + argset.remove(true) if true in argset else argset.add(true) + for a, b in remove: + argset.remove(a) + argset.remove(b) + if len(argset) == 0: + return false + elif len(argset) == 1: + return argset.pop() + elif True in argset and remove_true: + argset.remove(True) + return Not(Xor(*argset)) + else: + obj._args = tuple(ordered(argset)) + obj._argset = frozenset(argset) + return obj + + # XXX: This should be cached on the object rather than using cacheit + # Maybe it can be computed in __new__? + @property # type: ignore + @cacheit + def args(self): + return tuple(ordered(self._argset)) + + def to_nnf(self, simplify=True): + args = [] + for i in range(0, len(self.args)+1, 2): + for neg in combinations(self.args, i): + clause = [Not(s) if s in neg else s for s in self.args] + args.append(Or(*clause)) + return And._to_nnf(*args, simplify=simplify) + + def _eval_rewrite_as_Or(self, *args, **kwargs): + a = self.args + return Or(*[_convert_to_varsSOP(x, self.args) + for x in _get_odd_parity_terms(len(a))]) + + def _eval_rewrite_as_And(self, *args, **kwargs): + a = self.args + return And(*[_convert_to_varsPOS(x, self.args) + for x in _get_even_parity_terms(len(a))]) + + def _eval_simplify(self, **kwargs): + # as standard simplify uses simplify_logic which writes things as + # And and Or, we only simplify the partial expressions before using + # patterns + rv = self.func(*[a.simplify(**kwargs) for a in self.args]) + rv = rv.to_anf() + if not isinstance(rv, Xor): # This shouldn't really happen here + return rv + patterns = _simplify_patterns_xor() + return _apply_patternbased_simplification(rv, patterns, + kwargs['measure'], None) + + def _eval_subs(self, old, new): + # If old is Xor, replace the parts of the arguments with new if all + # are there + if isinstance(old, Xor): + old_set = set(old.args) + if old_set.issubset(self.args): + args = set(self.args) - old_set + args.add(new) + return self.func(*args) + + +class Nand(BooleanFunction): + """ + Logical NAND function. + + It evaluates its arguments in order, giving True immediately if any + of them are False, and False if they are all True. + + Returns True if any of the arguments are False + Returns False if all arguments are True + + Examples + ======== + + >>> from sympy.logic.boolalg import Nand + >>> from sympy import symbols + >>> x, y = symbols('x y') + >>> Nand(False, True) + True + >>> Nand(True, True) + False + >>> Nand(x, y) + ~(x & y) + + """ + @classmethod + def eval(cls, *args): + return Not(And(*args)) + + +class Nor(BooleanFunction): + """ + Logical NOR function. + + It evaluates its arguments in order, giving False immediately if any + of them are True, and True if they are all False. + + Returns False if any argument is True + Returns True if all arguments are False + + Examples + ======== + + >>> from sympy.logic.boolalg import Nor + >>> from sympy import symbols + >>> x, y = symbols('x y') + + >>> Nor(True, False) + False + >>> Nor(True, True) + False + >>> Nor(False, True) + False + >>> Nor(False, False) + True + >>> Nor(x, y) + ~(x | y) + + """ + @classmethod + def eval(cls, *args): + return Not(Or(*args)) + + +class Xnor(BooleanFunction): + """ + Logical XNOR function. + + Returns False if an odd number of the arguments are True and the rest are + False. + + Returns True if an even number of the arguments are True and the rest are + False. + + Examples + ======== + + >>> from sympy.logic.boolalg import Xnor + >>> from sympy import symbols + >>> x, y = symbols('x y') + >>> Xnor(True, False) + False + >>> Xnor(True, True) + True + >>> Xnor(True, False, True, True, False) + False + >>> Xnor(True, False, True, False) + True + + """ + @classmethod + def eval(cls, *args): + return Not(Xor(*args)) + + +class Implies(BooleanFunction): + r""" + Logical implication. + + A implies B is equivalent to if A then B. Mathematically, it is written + as `A \Rightarrow B` and is equivalent to `\neg A \vee B` or ``~A | B``. + + Accepts two Boolean arguments; A and B. + Returns False if A is True and B is False + Returns True otherwise. + + Examples + ======== + + >>> from sympy.logic.boolalg import Implies + >>> from sympy import symbols + >>> x, y = symbols('x y') + + >>> Implies(True, False) + False + >>> Implies(False, False) + True + >>> Implies(True, True) + True + >>> Implies(False, True) + True + >>> x >> y + Implies(x, y) + >>> y << x + Implies(x, y) + + Notes + ===== + + The ``>>`` and ``<<`` operators are provided as a convenience, but note + that their use here is different from their normal use in Python, which is + bit shifts. Hence, ``Implies(a, b)`` and ``a >> b`` will return different + things if ``a`` and ``b`` are integers. In particular, since Python + considers ``True`` and ``False`` to be integers, ``True >> True`` will be + the same as ``1 >> 1``, i.e., 0, which has a truth value of False. To + avoid this issue, use the SymPy objects ``true`` and ``false``. + + >>> from sympy import true, false + >>> True >> False + 1 + >>> true >> false + False + + """ + @classmethod + def eval(cls, *args): + try: + newargs = [] + for x in args: + if isinstance(x, Number) or x in (0, 1): + newargs.append(bool(x)) + else: + newargs.append(x) + A, B = newargs + except ValueError: + raise ValueError( + "%d operand(s) used for an Implies " + "(pairs are required): %s" % (len(args), str(args))) + if A in (True, False) or B in (True, False): + return Or(Not(A), B) + elif A == B: + return true + elif A.is_Relational and B.is_Relational: + if A.canonical == B.canonical: + return true + if A.negated.canonical == B.canonical: + return B + else: + return Basic.__new__(cls, *args) + + def to_nnf(self, simplify=True): + a, b = self.args + return Or._to_nnf(Not(a), b, simplify=simplify) + + def to_anf(self, deep=True): + a, b = self.args + return Xor._to_anf(true, a, And(a, b), deep=deep) + + +class Equivalent(BooleanFunction): + """ + Equivalence relation. + + ``Equivalent(A, B)`` is True iff A and B are both True or both False. + + Returns True if all of the arguments are logically equivalent. + Returns False otherwise. + + For two arguments, this is equivalent to :py:class:`~.Xnor`. + + Examples + ======== + + >>> from sympy.logic.boolalg import Equivalent, And + >>> from sympy.abc import x + >>> Equivalent(False, False, False) + True + >>> Equivalent(True, False, False) + False + >>> Equivalent(x, And(x, True)) + True + + """ + def __new__(cls, *args, **options): + from sympy.core.relational import Relational + args = [_sympify(arg) for arg in args] + + argset = set(args) + for x in args: + if isinstance(x, Number) or x in [True, False]: # Includes 0, 1 + argset.discard(x) + argset.add(bool(x)) + rel = [] + for r in argset: + if isinstance(r, Relational): + rel.append((r, r.canonical, r.negated.canonical)) + remove = [] + for i, (r, c, nc) in enumerate(rel): + for j in range(i + 1, len(rel)): + rj, cj = rel[j][:2] + if cj == nc: + return false + elif cj == c: + remove.append((r, rj)) + break + for a, b in remove: + argset.remove(a) + argset.remove(b) + argset.add(True) + if len(argset) <= 1: + return true + if True in argset: + argset.discard(True) + return And(*argset) + if False in argset: + argset.discard(False) + return And(*[Not(arg) for arg in argset]) + _args = frozenset(argset) + obj = super().__new__(cls, _args) + obj._argset = _args + return obj + + # XXX: This should be cached on the object rather than using cacheit + # Maybe it can be computed in __new__? + @property # type: ignore + @cacheit + def args(self): + return tuple(ordered(self._argset)) + + def to_nnf(self, simplify=True): + args = [] + for a, b in zip(self.args, self.args[1:]): + args.append(Or(Not(a), b)) + args.append(Or(Not(self.args[-1]), self.args[0])) + return And._to_nnf(*args, simplify=simplify) + + def to_anf(self, deep=True): + a = And(*self.args) + b = And(*[to_anf(Not(arg), deep=False) for arg in self.args]) + b = distribute_xor_over_and(b) + return Xor._to_anf(a, b, deep=deep) + + +class ITE(BooleanFunction): + """ + If-then-else clause. + + ``ITE(A, B, C)`` evaluates and returns the result of B if A is true + else it returns the result of C. All args must be Booleans. + + From a logic gate perspective, ITE corresponds to a 2-to-1 multiplexer, + where A is the select signal. + + Examples + ======== + + >>> from sympy.logic.boolalg import ITE, And, Xor, Or + >>> from sympy.abc import x, y, z + >>> ITE(True, False, True) + False + >>> ITE(Or(True, False), And(True, True), Xor(True, True)) + True + >>> ITE(x, y, z) + ITE(x, y, z) + >>> ITE(True, x, y) + x + >>> ITE(False, x, y) + y + >>> ITE(x, y, y) + y + + Trying to use non-Boolean args will generate a TypeError: + + >>> ITE(True, [], ()) + Traceback (most recent call last): + ... + TypeError: expecting bool, Boolean or ITE, not `[]` + + """ + def __new__(cls, *args, **kwargs): + from sympy.core.relational import Eq, Ne + if len(args) != 3: + raise ValueError('expecting exactly 3 args') + a, b, c = args + # check use of binary symbols + if isinstance(a, (Eq, Ne)): + # in this context, we can evaluate the Eq/Ne + # if one arg is a binary symbol and the other + # is true/false + b, c = map(as_Boolean, (b, c)) + bin_syms = set().union(*[i.binary_symbols for i in (b, c)]) + if len(set(a.args) - bin_syms) == 1: + # one arg is a binary_symbols + _a = a + if a.lhs is true: + a = a.rhs + elif a.rhs is true: + a = a.lhs + elif a.lhs is false: + a = Not(a.rhs) + elif a.rhs is false: + a = Not(a.lhs) + else: + # binary can only equal True or False + a = false + if isinstance(_a, Ne): + a = Not(a) + else: + a, b, c = BooleanFunction.binary_check_and_simplify( + a, b, c) + rv = None + if kwargs.get('evaluate', True): + rv = cls.eval(a, b, c) + if rv is None: + rv = BooleanFunction.__new__(cls, a, b, c, evaluate=False) + return rv + + @classmethod + def eval(cls, *args): + from sympy.core.relational import Eq, Ne + # do the args give a singular result? + a, b, c = args + if isinstance(a, (Ne, Eq)): + _a = a + if true in a.args: + a = a.lhs if a.rhs is true else a.rhs + elif false in a.args: + a = Not(a.lhs) if a.rhs is false else Not(a.rhs) + else: + _a = None + if _a is not None and isinstance(_a, Ne): + a = Not(a) + if a is true: + return b + if a is false: + return c + if b == c: + return b + else: + # or maybe the results allow the answer to be expressed + # in terms of the condition + if b is true and c is false: + return a + if b is false and c is true: + return Not(a) + if [a, b, c] != args: + return cls(a, b, c, evaluate=False) + + def to_nnf(self, simplify=True): + a, b, c = self.args + return And._to_nnf(Or(Not(a), b), Or(a, c), simplify=simplify) + + def _eval_as_set(self): + return self.to_nnf().as_set() + + def _eval_rewrite_as_Piecewise(self, *args, **kwargs): + from sympy.functions.elementary.piecewise import Piecewise + return Piecewise((args[1], args[0]), (args[2], True)) + + +class Exclusive(BooleanFunction): + """ + True if only one or no argument is true. + + ``Exclusive(A, B, C)`` is equivalent to ``~(A & B) & ~(A & C) & ~(B & C)``. + + For two arguments, this is equivalent to :py:class:`~.Xor`. + + Examples + ======== + + >>> from sympy.logic.boolalg import Exclusive + >>> Exclusive(False, False, False) + True + >>> Exclusive(False, True, False) + True + >>> Exclusive(False, True, True) + False + + """ + @classmethod + def eval(cls, *args): + and_args = [] + for a, b in combinations(args, 2): + and_args.append(Not(And(a, b))) + return And(*and_args) + + +# end class definitions. Some useful methods + + +def conjuncts(expr): + """Return a list of the conjuncts in ``expr``. + + Examples + ======== + + >>> from sympy.logic.boolalg import conjuncts + >>> from sympy.abc import A, B + >>> conjuncts(A & B) + frozenset({A, B}) + >>> conjuncts(A | B) + frozenset({A | B}) + + """ + return And.make_args(expr) + + +def disjuncts(expr): + """Return a list of the disjuncts in ``expr``. + + Examples + ======== + + >>> from sympy.logic.boolalg import disjuncts + >>> from sympy.abc import A, B + >>> disjuncts(A | B) + frozenset({A, B}) + >>> disjuncts(A & B) + frozenset({A & B}) + + """ + return Or.make_args(expr) + + +def distribute_and_over_or(expr): + """ + Given a sentence ``expr`` consisting of conjunctions and disjunctions + of literals, return an equivalent sentence in CNF. + + Examples + ======== + + >>> from sympy.logic.boolalg import distribute_and_over_or, And, Or, Not + >>> from sympy.abc import A, B, C + >>> distribute_and_over_or(Or(A, And(Not(B), Not(C)))) + (A | ~B) & (A | ~C) + + """ + return _distribute((expr, And, Or)) + + +def distribute_or_over_and(expr): + """ + Given a sentence ``expr`` consisting of conjunctions and disjunctions + of literals, return an equivalent sentence in DNF. + + Note that the output is NOT simplified. + + Examples + ======== + + >>> from sympy.logic.boolalg import distribute_or_over_and, And, Or, Not + >>> from sympy.abc import A, B, C + >>> distribute_or_over_and(And(Or(Not(A), B), C)) + (B & C) | (C & ~A) + + """ + return _distribute((expr, Or, And)) + + +def distribute_xor_over_and(expr): + """ + Given a sentence ``expr`` consisting of conjunction and + exclusive disjunctions of literals, return an + equivalent exclusive disjunction. + + Note that the output is NOT simplified. + + Examples + ======== + + >>> from sympy.logic.boolalg import distribute_xor_over_and, And, Xor, Not + >>> from sympy.abc import A, B, C + >>> distribute_xor_over_and(And(Xor(Not(A), B), C)) + (B & C) ^ (C & ~A) + """ + return _distribute((expr, Xor, And)) + + +def _distribute(info): + """ + Distributes ``info[1]`` over ``info[2]`` with respect to ``info[0]``. + """ + if isinstance(info[0], info[2]): + for arg in info[0].args: + if isinstance(arg, info[1]): + conj = arg + break + else: + return info[0] + rest = info[2](*[a for a in info[0].args if a is not conj]) + return info[1](*list(map(_distribute, + [(info[2](c, rest), info[1], info[2]) + for c in conj.args])), remove_true=False) + elif isinstance(info[0], info[1]): + return info[1](*list(map(_distribute, + [(x, info[1], info[2]) + for x in info[0].args])), + remove_true=False) + else: + return info[0] + + +def to_anf(expr, deep=True): + r""" + Converts expr to Algebraic Normal Form (ANF). + + ANF is a canonical normal form, which means that two + equivalent formulas will convert to the same ANF. + + A logical expression is in ANF if it has the form + + .. math:: 1 \oplus a \oplus b \oplus ab \oplus abc + + i.e. it can be: + - purely true, + - purely false, + - conjunction of variables, + - exclusive disjunction. + + The exclusive disjunction can only contain true, variables + or conjunction of variables. No negations are permitted. + + If ``deep`` is ``False``, arguments of the boolean + expression are considered variables, i.e. only the + top-level expression is converted to ANF. + + Examples + ======== + >>> from sympy.logic.boolalg import And, Or, Not, Implies, Equivalent + >>> from sympy.logic.boolalg import to_anf + >>> from sympy.abc import A, B, C + >>> to_anf(Not(A)) + A ^ True + >>> to_anf(And(Or(A, B), Not(C))) + A ^ B ^ (A & B) ^ (A & C) ^ (B & C) ^ (A & B & C) + >>> to_anf(Implies(Not(A), Equivalent(B, C)), deep=False) + True ^ ~A ^ (~A & (Equivalent(B, C))) + + """ + expr = sympify(expr) + + if is_anf(expr): + return expr + return expr.to_anf(deep=deep) + + +def to_nnf(expr, simplify=True): + """ + Converts ``expr`` to Negation Normal Form (NNF). + + A logical expression is in NNF if it + contains only :py:class:`~.And`, :py:class:`~.Or` and :py:class:`~.Not`, + and :py:class:`~.Not` is applied only to literals. + If ``simplify`` is ``True``, the result contains no redundant clauses. + + Examples + ======== + + >>> from sympy.abc import A, B, C, D + >>> from sympy.logic.boolalg import Not, Equivalent, to_nnf + >>> to_nnf(Not((~A & ~B) | (C & D))) + (A | B) & (~C | ~D) + >>> to_nnf(Equivalent(A >> B, B >> A)) + (A | ~B | (A & ~B)) & (B | ~A | (B & ~A)) + + """ + if is_nnf(expr, simplify): + return expr + return expr.to_nnf(simplify) + + +def to_cnf(expr, simplify=False, force=False): + """ + Convert a propositional logical sentence ``expr`` to conjunctive normal + form: ``((A | ~B | ...) & (B | C | ...) & ...)``. + If ``simplify`` is ``True``, ``expr`` is evaluated to its simplest CNF + form using the Quine-McCluskey algorithm; this may take a long + time. If there are more than 8 variables the ``force`` flag must be set + to ``True`` to simplify (default is ``False``). + + Examples + ======== + + >>> from sympy.logic.boolalg import to_cnf + >>> from sympy.abc import A, B, D + >>> to_cnf(~(A | B) | D) + (D | ~A) & (D | ~B) + >>> to_cnf((A | B) & (A | ~A), True) + A | B + + """ + expr = sympify(expr) + if not isinstance(expr, BooleanFunction): + return expr + + if simplify: + if not force and len(_find_predicates(expr)) > 8: + raise ValueError(filldedent(''' + To simplify a logical expression with more + than 8 variables may take a long time and requires + the use of `force=True`.''')) + return simplify_logic(expr, 'cnf', True, force=force) + + # Don't convert unless we have to + if is_cnf(expr): + return expr + + expr = eliminate_implications(expr) + res = distribute_and_over_or(expr) + + return res + + +def to_dnf(expr, simplify=False, force=False): + """ + Convert a propositional logical sentence ``expr`` to disjunctive normal + form: ``((A & ~B & ...) | (B & C & ...) | ...)``. + If ``simplify`` is ``True``, ``expr`` is evaluated to its simplest DNF form using + the Quine-McCluskey algorithm; this may take a long + time. If there are more than 8 variables, the ``force`` flag must be set to + ``True`` to simplify (default is ``False``). + + Examples + ======== + + >>> from sympy.logic.boolalg import to_dnf + >>> from sympy.abc import A, B, C + >>> to_dnf(B & (A | C)) + (A & B) | (B & C) + >>> to_dnf((A & B) | (A & ~B) | (B & C) | (~B & C), True) + A | C + + """ + expr = sympify(expr) + if not isinstance(expr, BooleanFunction): + return expr + + if simplify: + if not force and len(_find_predicates(expr)) > 8: + raise ValueError(filldedent(''' + To simplify a logical expression with more + than 8 variables may take a long time and requires + the use of `force=True`.''')) + return simplify_logic(expr, 'dnf', True, force=force) + + # Don't convert unless we have to + if is_dnf(expr): + return expr + + expr = eliminate_implications(expr) + return distribute_or_over_and(expr) + + +def is_anf(expr): + r""" + Checks if ``expr`` is in Algebraic Normal Form (ANF). + + A logical expression is in ANF if it has the form + + .. math:: 1 \oplus a \oplus b \oplus ab \oplus abc + + i.e. it is purely true, purely false, conjunction of + variables or exclusive disjunction. The exclusive + disjunction can only contain true, variables or + conjunction of variables. No negations are permitted. + + Examples + ======== + + >>> from sympy.logic.boolalg import And, Not, Xor, true, is_anf + >>> from sympy.abc import A, B, C + >>> is_anf(true) + True + >>> is_anf(A) + True + >>> is_anf(And(A, B, C)) + True + >>> is_anf(Xor(A, Not(B))) + False + + """ + expr = sympify(expr) + + if is_literal(expr) and not isinstance(expr, Not): + return True + + if isinstance(expr, And): + for arg in expr.args: + if not arg.is_Symbol: + return False + return True + + elif isinstance(expr, Xor): + for arg in expr.args: + if isinstance(arg, And): + for a in arg.args: + if not a.is_Symbol: + return False + elif is_literal(arg): + if isinstance(arg, Not): + return False + else: + return False + return True + + else: + return False + + +def is_nnf(expr, simplified=True): + """ + Checks if ``expr`` is in Negation Normal Form (NNF). + + A logical expression is in NNF if it + contains only :py:class:`~.And`, :py:class:`~.Or` and :py:class:`~.Not`, + and :py:class:`~.Not` is applied only to literals. + If ``simplified`` is ``True``, checks if result contains no redundant clauses. + + Examples + ======== + + >>> from sympy.abc import A, B, C + >>> from sympy.logic.boolalg import Not, is_nnf + >>> is_nnf(A & B | ~C) + True + >>> is_nnf((A | ~A) & (B | C)) + False + >>> is_nnf((A | ~A) & (B | C), False) + True + >>> is_nnf(Not(A & B) | C) + False + >>> is_nnf((A >> B) & (B >> A)) + False + + """ + + expr = sympify(expr) + if is_literal(expr): + return True + + stack = [expr] + + while stack: + expr = stack.pop() + if expr.func in (And, Or): + if simplified: + args = expr.args + for arg in args: + if Not(arg) in args: + return False + stack.extend(expr.args) + + elif not is_literal(expr): + return False + + return True + + +def is_cnf(expr): + """ + Test whether or not an expression is in conjunctive normal form. + + Examples + ======== + + >>> from sympy.logic.boolalg import is_cnf + >>> from sympy.abc import A, B, C + >>> is_cnf(A | B | C) + True + >>> is_cnf(A & B & C) + True + >>> is_cnf((A & B) | C) + False + + """ + return _is_form(expr, And, Or) + + +def is_dnf(expr): + """ + Test whether or not an expression is in disjunctive normal form. + + Examples + ======== + + >>> from sympy.logic.boolalg import is_dnf + >>> from sympy.abc import A, B, C + >>> is_dnf(A | B | C) + True + >>> is_dnf(A & B & C) + True + >>> is_dnf((A & B) | C) + True + >>> is_dnf(A & (B | C)) + False + + """ + return _is_form(expr, Or, And) + + +def _is_form(expr, function1, function2): + """ + Test whether or not an expression is of the required form. + + """ + expr = sympify(expr) + + vals = function1.make_args(expr) if isinstance(expr, function1) else [expr] + for lit in vals: + if isinstance(lit, function2): + vals2 = function2.make_args(lit) if isinstance(lit, function2) else [lit] + for l in vals2: + if is_literal(l) is False: + return False + elif is_literal(lit) is False: + return False + + return True + + +def eliminate_implications(expr): + """ + Change :py:class:`~.Implies` and :py:class:`~.Equivalent` into + :py:class:`~.And`, :py:class:`~.Or`, and :py:class:`~.Not`. + That is, return an expression that is equivalent to ``expr``, but has only + ``&``, ``|``, and ``~`` as logical + operators. + + Examples + ======== + + >>> from sympy.logic.boolalg import Implies, Equivalent, \ + eliminate_implications + >>> from sympy.abc import A, B, C + >>> eliminate_implications(Implies(A, B)) + B | ~A + >>> eliminate_implications(Equivalent(A, B)) + (A | ~B) & (B | ~A) + >>> eliminate_implications(Equivalent(A, B, C)) + (A | ~C) & (B | ~A) & (C | ~B) + + """ + return to_nnf(expr, simplify=False) + + +def is_literal(expr): + """ + Returns True if expr is a literal, else False. + + Examples + ======== + + >>> from sympy import Or, Q + >>> from sympy.abc import A, B + >>> from sympy.logic.boolalg import is_literal + >>> is_literal(A) + True + >>> is_literal(~A) + True + >>> is_literal(Q.zero(A)) + True + >>> is_literal(A + B) + True + >>> is_literal(Or(A, B)) + False + + """ + from sympy.assumptions import AppliedPredicate + + if isinstance(expr, Not): + return is_literal(expr.args[0]) + elif expr in (True, False) or isinstance(expr, AppliedPredicate) or expr.is_Atom: + return True + elif not isinstance(expr, BooleanFunction) and all( + (isinstance(expr, AppliedPredicate) or a.is_Atom) for a in expr.args): + return True + return False + + +def to_int_repr(clauses, symbols): + """ + Takes clauses in CNF format and puts them into an integer representation. + + Examples + ======== + + >>> from sympy.logic.boolalg import to_int_repr + >>> from sympy.abc import x, y + >>> to_int_repr([x | y, y], [x, y]) == [{1, 2}, {2}] + True + + """ + + # Convert the symbol list into a dict + symbols = dict(zip(symbols, range(1, len(symbols) + 1))) + + def append_symbol(arg, symbols): + if isinstance(arg, Not): + return -symbols[arg.args[0]] + else: + return symbols[arg] + + return [{append_symbol(arg, symbols) for arg in Or.make_args(c)} + for c in clauses] + + +def term_to_integer(term): + """ + Return an integer corresponding to the base-2 digits given by *term*. + + Parameters + ========== + + term : a string or list of ones and zeros + + Examples + ======== + + >>> from sympy.logic.boolalg import term_to_integer + >>> term_to_integer([1, 0, 0]) + 4 + >>> term_to_integer('100') + 4 + + """ + + return int(''.join(list(map(str, list(term)))), 2) + + +integer_to_term = ibin # XXX could delete? + + +def truth_table(expr, variables, input=True): + """ + Return a generator of all possible configurations of the input variables, + and the result of the boolean expression for those values. + + Parameters + ========== + + expr : Boolean expression + + variables : list of variables + + input : bool (default ``True``) + Indicates whether to return the input combinations. + + Examples + ======== + + >>> from sympy.logic.boolalg import truth_table + >>> from sympy.abc import x,y + >>> table = truth_table(x >> y, [x, y]) + >>> for t in table: + ... print('{0} -> {1}'.format(*t)) + [0, 0] -> True + [0, 1] -> True + [1, 0] -> False + [1, 1] -> True + + >>> table = truth_table(x | y, [x, y]) + >>> list(table) + [([0, 0], False), ([0, 1], True), ([1, 0], True), ([1, 1], True)] + + If ``input`` is ``False``, ``truth_table`` returns only a list of truth values. + In this case, the corresponding input values of variables can be + deduced from the index of a given output. + + >>> from sympy.utilities.iterables import ibin + >>> vars = [y, x] + >>> values = truth_table(x >> y, vars, input=False) + >>> values = list(values) + >>> values + [True, False, True, True] + + >>> for i, value in enumerate(values): + ... print('{0} -> {1}'.format(list(zip( + ... vars, ibin(i, len(vars)))), value)) + [(y, 0), (x, 0)] -> True + [(y, 0), (x, 1)] -> False + [(y, 1), (x, 0)] -> True + [(y, 1), (x, 1)] -> True + + """ + variables = [sympify(v) for v in variables] + + expr = sympify(expr) + if not isinstance(expr, BooleanFunction) and not is_literal(expr): + return + + table = product((0, 1), repeat=len(variables)) + for term in table: + value = expr.xreplace(dict(zip(variables, term))) + + if input: + yield list(term), value + else: + yield value + + +def _check_pair(minterm1, minterm2): + """ + Checks if a pair of minterms differs by only one bit. If yes, returns + index, else returns `-1`. + """ + # Early termination seems to be faster than list comprehension, + # at least for large examples. + index = -1 + for x, i in enumerate(minterm1): # zip(minterm1, minterm2) is slower + if i != minterm2[x]: + if index == -1: + index = x + else: + return -1 + return index + + +def _convert_to_varsSOP(minterm, variables): + """ + Converts a term in the expansion of a function from binary to its + variable form (for SOP). + """ + temp = [variables[n] if val == 1 else Not(variables[n]) + for n, val in enumerate(minterm) if val != 3] + return And(*temp) + + +def _convert_to_varsPOS(maxterm, variables): + """ + Converts a term in the expansion of a function from binary to its + variable form (for POS). + """ + temp = [variables[n] if val == 0 else Not(variables[n]) + for n, val in enumerate(maxterm) if val != 3] + return Or(*temp) + + +def _convert_to_varsANF(term, variables): + """ + Converts a term in the expansion of a function from binary to its + variable form (for ANF). + + Parameters + ========== + + term : list of 1's and 0's (complementation pattern) + variables : list of variables + + """ + temp = [variables[n] for n, t in enumerate(term) if t == 1] + + if not temp: + return true + + return And(*temp) + + +def _get_odd_parity_terms(n): + """ + Returns a list of lists, with all possible combinations of n zeros and ones + with an odd number of ones. + """ + return [e for e in [ibin(i, n) for i in range(2**n)] if sum(e) % 2 == 1] + + +def _get_even_parity_terms(n): + """ + Returns a list of lists, with all possible combinations of n zeros and ones + with an even number of ones. + """ + return [e for e in [ibin(i, n) for i in range(2**n)] if sum(e) % 2 == 0] + + +def _simplified_pairs(terms): + """ + Reduces a set of minterms, if possible, to a simplified set of minterms + with one less variable in the terms using QM method. + """ + if not terms: + return [] + + simplified_terms = [] + todo = list(range(len(terms))) + + # Count number of ones as _check_pair can only potentially match if there + # is at most a difference of a single one + termdict = defaultdict(list) + for n, term in enumerate(terms): + ones = sum(1 for t in term if t == 1) + termdict[ones].append(n) + + variables = len(terms[0]) + for k in range(variables): + for i in termdict[k]: + for j in termdict[k+1]: + index = _check_pair(terms[i], terms[j]) + if index != -1: + # Mark terms handled + todo[i] = todo[j] = None + # Copy old term + newterm = terms[i][:] + # Set differing position to don't care + newterm[index] = 3 + # Add if not already there + if newterm not in simplified_terms: + simplified_terms.append(newterm) + + if simplified_terms: + # Further simplifications only among the new terms + simplified_terms = _simplified_pairs(simplified_terms) + + # Add remaining, non-simplified, terms + simplified_terms.extend([terms[i] for i in todo if i is not None]) + return simplified_terms + + +def _rem_redundancy(l1, terms): + """ + After the truth table has been sufficiently simplified, use the prime + implicant table method to recognize and eliminate redundant pairs, + and return the essential arguments. + """ + + if not terms: + return [] + + nterms = len(terms) + nl1 = len(l1) + + # Create dominating matrix + dommatrix = [[0]*nl1 for n in range(nterms)] + colcount = [0]*nl1 + rowcount = [0]*nterms + for primei, prime in enumerate(l1): + for termi, term in enumerate(terms): + # Check prime implicant covering term + if all(t == 3 or t == mt for t, mt in zip(prime, term)): + dommatrix[termi][primei] = 1 + colcount[primei] += 1 + rowcount[termi] += 1 + + # Keep track if anything changed + anythingchanged = True + # Then, go again + while anythingchanged: + anythingchanged = False + + for rowi in range(nterms): + # Still non-dominated? + if rowcount[rowi]: + row = dommatrix[rowi] + for row2i in range(nterms): + # Still non-dominated? + if rowi != row2i and rowcount[rowi] and (rowcount[rowi] <= rowcount[row2i]): + row2 = dommatrix[row2i] + if all(row2[n] >= row[n] for n in range(nl1)): + # row2 dominating row, remove row2 + rowcount[row2i] = 0 + anythingchanged = True + for primei, prime in enumerate(row2): + if prime: + # Make corresponding entry 0 + dommatrix[row2i][primei] = 0 + colcount[primei] -= 1 + + colcache = {} + + for coli in range(nl1): + # Still non-dominated? + if colcount[coli]: + if coli in colcache: + col = colcache[coli] + else: + col = [dommatrix[i][coli] for i in range(nterms)] + colcache[coli] = col + for col2i in range(nl1): + # Still non-dominated? + if coli != col2i and colcount[col2i] and (colcount[coli] >= colcount[col2i]): + if col2i in colcache: + col2 = colcache[col2i] + else: + col2 = [dommatrix[i][col2i] for i in range(nterms)] + colcache[col2i] = col2 + if all(col[n] >= col2[n] for n in range(nterms)): + # col dominating col2, remove col2 + colcount[col2i] = 0 + anythingchanged = True + for termi, term in enumerate(col2): + if term and dommatrix[termi][col2i]: + # Make corresponding entry 0 + dommatrix[termi][col2i] = 0 + rowcount[termi] -= 1 + + if not anythingchanged: + # Heuristically select the prime implicant covering most terms + maxterms = 0 + bestcolidx = -1 + for coli in range(nl1): + s = colcount[coli] + if s > maxterms: + bestcolidx = coli + maxterms = s + + # In case we found a prime implicant covering at least two terms + if bestcolidx != -1 and maxterms > 1: + for primei, prime in enumerate(l1): + if primei != bestcolidx: + for termi, term in enumerate(colcache[bestcolidx]): + if term and dommatrix[termi][primei]: + # Make corresponding entry 0 + dommatrix[termi][primei] = 0 + anythingchanged = True + rowcount[termi] -= 1 + colcount[primei] -= 1 + + return [l1[i] for i in range(nl1) if colcount[i]] + + +def _input_to_binlist(inputlist, variables): + binlist = [] + bits = len(variables) + for val in inputlist: + if isinstance(val, int): + binlist.append(ibin(val, bits)) + elif isinstance(val, dict): + nonspecvars = list(variables) + for key in val.keys(): + nonspecvars.remove(key) + for t in product((0, 1), repeat=len(nonspecvars)): + d = dict(zip(nonspecvars, t)) + d.update(val) + binlist.append([d[v] for v in variables]) + elif isinstance(val, (list, tuple)): + if len(val) != bits: + raise ValueError("Each term must contain {bits} bits as there are" + "\n{bits} variables (or be an integer)." + "".format(bits=bits)) + binlist.append(list(val)) + else: + raise TypeError("A term list can only contain lists," + " ints or dicts.") + return binlist + + +def SOPform(variables, minterms, dontcares=None): + """ + The SOPform function uses simplified_pairs and a redundant group- + eliminating algorithm to convert the list of all input combos that + generate '1' (the minterms) into the smallest sum-of-products form. + + The variables must be given as the first argument. + + Return a logical :py:class:`~.Or` function (i.e., the "sum of products" or + "SOP" form) that gives the desired outcome. If there are inputs that can + be ignored, pass them as a list, too. + + The result will be one of the (perhaps many) functions that satisfy + the conditions. + + Examples + ======== + + >>> from sympy.logic import SOPform + >>> from sympy import symbols + >>> w, x, y, z = symbols('w x y z') + >>> minterms = [[0, 0, 0, 1], [0, 0, 1, 1], + ... [0, 1, 1, 1], [1, 0, 1, 1], [1, 1, 1, 1]] + >>> dontcares = [[0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 1]] + >>> SOPform([w, x, y, z], minterms, dontcares) + (y & z) | (~w & ~x) + + The terms can also be represented as integers: + + >>> minterms = [1, 3, 7, 11, 15] + >>> dontcares = [0, 2, 5] + >>> SOPform([w, x, y, z], minterms, dontcares) + (y & z) | (~w & ~x) + + They can also be specified using dicts, which does not have to be fully + specified: + + >>> minterms = [{w: 0, x: 1}, {y: 1, z: 1, x: 0}] + >>> SOPform([w, x, y, z], minterms) + (x & ~w) | (y & z & ~x) + + Or a combination: + + >>> minterms = [4, 7, 11, [1, 1, 1, 1]] + >>> dontcares = [{w : 0, x : 0, y: 0}, 5] + >>> SOPform([w, x, y, z], minterms, dontcares) + (w & y & z) | (~w & ~y) | (x & z & ~w) + + See also + ======== + + POSform + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Quine-McCluskey_algorithm + .. [2] https://en.wikipedia.org/wiki/Don%27t-care_term + + """ + if not minterms: + return false + + variables = tuple(map(sympify, variables)) + + + minterms = _input_to_binlist(minterms, variables) + dontcares = _input_to_binlist((dontcares or []), variables) + for d in dontcares: + if d in minterms: + raise ValueError('%s in minterms is also in dontcares' % d) + + return _sop_form(variables, minterms, dontcares) + + +def _sop_form(variables, minterms, dontcares): + new = _simplified_pairs(minterms + dontcares) + essential = _rem_redundancy(new, minterms) + return Or(*[_convert_to_varsSOP(x, variables) for x in essential]) + + +def POSform(variables, minterms, dontcares=None): + """ + The POSform function uses simplified_pairs and a redundant-group + eliminating algorithm to convert the list of all input combinations + that generate '1' (the minterms) into the smallest product-of-sums form. + + The variables must be given as the first argument. + + Return a logical :py:class:`~.And` function (i.e., the "product of sums" + or "POS" form) that gives the desired outcome. If there are inputs that can + be ignored, pass them as a list, too. + + The result will be one of the (perhaps many) functions that satisfy + the conditions. + + Examples + ======== + + >>> from sympy.logic import POSform + >>> from sympy import symbols + >>> w, x, y, z = symbols('w x y z') + >>> minterms = [[0, 0, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1], + ... [1, 0, 1, 1], [1, 1, 1, 1]] + >>> dontcares = [[0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 1]] + >>> POSform([w, x, y, z], minterms, dontcares) + z & (y | ~w) + + The terms can also be represented as integers: + + >>> minterms = [1, 3, 7, 11, 15] + >>> dontcares = [0, 2, 5] + >>> POSform([w, x, y, z], minterms, dontcares) + z & (y | ~w) + + They can also be specified using dicts, which does not have to be fully + specified: + + >>> minterms = [{w: 0, x: 1}, {y: 1, z: 1, x: 0}] + >>> POSform([w, x, y, z], minterms) + (x | y) & (x | z) & (~w | ~x) + + Or a combination: + + >>> minterms = [4, 7, 11, [1, 1, 1, 1]] + >>> dontcares = [{w : 0, x : 0, y: 0}, 5] + >>> POSform([w, x, y, z], minterms, dontcares) + (w | x) & (y | ~w) & (z | ~y) + + See also + ======== + + SOPform + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Quine-McCluskey_algorithm + .. [2] https://en.wikipedia.org/wiki/Don%27t-care_term + + """ + if not minterms: + return false + + variables = tuple(map(sympify, variables)) + minterms = _input_to_binlist(minterms, variables) + dontcares = _input_to_binlist((dontcares or []), variables) + for d in dontcares: + if d in minterms: + raise ValueError('%s in minterms is also in dontcares' % d) + + maxterms = [] + for t in product((0, 1), repeat=len(variables)): + t = list(t) + if (t not in minterms) and (t not in dontcares): + maxterms.append(t) + + new = _simplified_pairs(maxterms + dontcares) + essential = _rem_redundancy(new, maxterms) + return And(*[_convert_to_varsPOS(x, variables) for x in essential]) + + +def ANFform(variables, truthvalues): + """ + The ANFform function converts the list of truth values to + Algebraic Normal Form (ANF). + + The variables must be given as the first argument. + + Return True, False, logical :py:class:`~.And` function (i.e., the + "Zhegalkin monomial") or logical :py:class:`~.Xor` function (i.e., + the "Zhegalkin polynomial"). When True and False + are represented by 1 and 0, respectively, then + :py:class:`~.And` is multiplication and :py:class:`~.Xor` is addition. + + Formally a "Zhegalkin monomial" is the product (logical + And) of a finite set of distinct variables, including + the empty set whose product is denoted 1 (True). + A "Zhegalkin polynomial" is the sum (logical Xor) of a + set of Zhegalkin monomials, with the empty set denoted + by 0 (False). + + Parameters + ========== + + variables : list of variables + truthvalues : list of 1's and 0's (result column of truth table) + + Examples + ======== + >>> from sympy.logic.boolalg import ANFform + >>> from sympy.abc import x, y + >>> ANFform([x], [1, 0]) + x ^ True + >>> ANFform([x, y], [0, 1, 1, 1]) + x ^ y ^ (x & y) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Zhegalkin_polynomial + + """ + + n_vars = len(variables) + n_values = len(truthvalues) + + if n_values != 2 ** n_vars: + raise ValueError("The number of truth values must be equal to 2^%d, " + "got %d" % (n_vars, n_values)) + + variables = tuple(map(sympify, variables)) + + coeffs = anf_coeffs(truthvalues) + terms = [] + + for i, t in enumerate(product((0, 1), repeat=n_vars)): + if coeffs[i] == 1: + terms.append(t) + + return Xor(*[_convert_to_varsANF(x, variables) for x in terms], + remove_true=False) + + +def anf_coeffs(truthvalues): + """ + Convert a list of truth values of some boolean expression + to the list of coefficients of the polynomial mod 2 (exclusive + disjunction) representing the boolean expression in ANF + (i.e., the "Zhegalkin polynomial"). + + There are `2^n` possible Zhegalkin monomials in `n` variables, since + each monomial is fully specified by the presence or absence of + each variable. + + We can enumerate all the monomials. For example, boolean + function with four variables ``(a, b, c, d)`` can contain + up to `2^4 = 16` monomials. The 13-th monomial is the + product ``a & b & d``, because 13 in binary is 1, 1, 0, 1. + + A given monomial's presence or absence in a polynomial corresponds + to that monomial's coefficient being 1 or 0 respectively. + + Examples + ======== + >>> from sympy.logic.boolalg import anf_coeffs, bool_monomial, Xor + >>> from sympy.abc import a, b, c + >>> truthvalues = [0, 1, 1, 0, 0, 1, 0, 1] + >>> coeffs = anf_coeffs(truthvalues) + >>> coeffs + [0, 1, 1, 0, 0, 0, 1, 0] + >>> polynomial = Xor(*[ + ... bool_monomial(k, [a, b, c]) + ... for k, coeff in enumerate(coeffs) if coeff == 1 + ... ]) + >>> polynomial + b ^ c ^ (a & b) + + """ + + s = '{:b}'.format(len(truthvalues)) + n = len(s) - 1 + + if len(truthvalues) != 2**n: + raise ValueError("The number of truth values must be a power of two, " + "got %d" % len(truthvalues)) + + coeffs = [[v] for v in truthvalues] + + for i in range(n): + tmp = [] + for j in range(2 ** (n-i-1)): + tmp.append(coeffs[2*j] + + list(map(lambda x, y: x^y, coeffs[2*j], coeffs[2*j+1]))) + coeffs = tmp + + return coeffs[0] + + +def bool_minterm(k, variables): + """ + Return the k-th minterm. + + Minterms are numbered by a binary encoding of the complementation + pattern of the variables. This convention assigns the value 1 to + the direct form and 0 to the complemented form. + + Parameters + ========== + + k : int or list of 1's and 0's (complementation pattern) + variables : list of variables + + Examples + ======== + + >>> from sympy.logic.boolalg import bool_minterm + >>> from sympy.abc import x, y, z + >>> bool_minterm([1, 0, 1], [x, y, z]) + x & z & ~y + >>> bool_minterm(6, [x, y, z]) + x & y & ~z + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Canonical_normal_form#Indexing_minterms + + """ + if isinstance(k, int): + k = ibin(k, len(variables)) + variables = tuple(map(sympify, variables)) + return _convert_to_varsSOP(k, variables) + + +def bool_maxterm(k, variables): + """ + Return the k-th maxterm. + + Each maxterm is assigned an index based on the opposite + conventional binary encoding used for minterms. The maxterm + convention assigns the value 0 to the direct form and 1 to + the complemented form. + + Parameters + ========== + + k : int or list of 1's and 0's (complementation pattern) + variables : list of variables + + Examples + ======== + >>> from sympy.logic.boolalg import bool_maxterm + >>> from sympy.abc import x, y, z + >>> bool_maxterm([1, 0, 1], [x, y, z]) + y | ~x | ~z + >>> bool_maxterm(6, [x, y, z]) + z | ~x | ~y + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Canonical_normal_form#Indexing_maxterms + + """ + if isinstance(k, int): + k = ibin(k, len(variables)) + variables = tuple(map(sympify, variables)) + return _convert_to_varsPOS(k, variables) + + +def bool_monomial(k, variables): + """ + Return the k-th monomial. + + Monomials are numbered by a binary encoding of the presence and + absences of the variables. This convention assigns the value + 1 to the presence of variable and 0 to the absence of variable. + + Each boolean function can be uniquely represented by a + Zhegalkin Polynomial (Algebraic Normal Form). The Zhegalkin + Polynomial of the boolean function with `n` variables can contain + up to `2^n` monomials. We can enumerate all the monomials. + Each monomial is fully specified by the presence or absence + of each variable. + + For example, boolean function with four variables ``(a, b, c, d)`` + can contain up to `2^4 = 16` monomials. The 13-th monomial is the + product ``a & b & d``, because 13 in binary is 1, 1, 0, 1. + + Parameters + ========== + + k : int or list of 1's and 0's + variables : list of variables + + Examples + ======== + >>> from sympy.logic.boolalg import bool_monomial + >>> from sympy.abc import x, y, z + >>> bool_monomial([1, 0, 1], [x, y, z]) + x & z + >>> bool_monomial(6, [x, y, z]) + x & y + + """ + if isinstance(k, int): + k = ibin(k, len(variables)) + variables = tuple(map(sympify, variables)) + return _convert_to_varsANF(k, variables) + + +def _find_predicates(expr): + """Helper to find logical predicates in BooleanFunctions. + + A logical predicate is defined here as anything within a BooleanFunction + that is not a BooleanFunction itself. + + """ + if not isinstance(expr, BooleanFunction): + return {expr} + return set().union(*(map(_find_predicates, expr.args))) + + +def simplify_logic(expr, form=None, deep=True, force=False, dontcare=None): + """ + This function simplifies a boolean function to its simplified version + in SOP or POS form. The return type is an :py:class:`~.Or` or + :py:class:`~.And` object in SymPy. + + Parameters + ========== + + expr : Boolean + + form : string (``'cnf'`` or ``'dnf'``) or ``None`` (default). + If ``'cnf'`` or ``'dnf'``, the simplest expression in the corresponding + normal form is returned; if ``None``, the answer is returned + according to the form with fewest args (in CNF by default). + + deep : bool (default ``True``) + Indicates whether to recursively simplify any + non-boolean functions contained within the input. + + force : bool (default ``False``) + As the simplifications require exponential time in the number + of variables, there is by default a limit on expressions with + 8 variables. When the expression has more than 8 variables + only symbolical simplification (controlled by ``deep``) is + made. By setting ``force`` to ``True``, this limit is removed. Be + aware that this can lead to very long simplification times. + + dontcare : Boolean + Optimize expression under the assumption that inputs where this + expression is true are don't care. This is useful in e.g. Piecewise + conditions, where later conditions do not need to consider inputs that + are converted by previous conditions. For example, if a previous + condition is ``And(A, B)``, the simplification of expr can be made + with don't cares for ``And(A, B)``. + + Examples + ======== + + >>> from sympy.logic import simplify_logic + >>> from sympy.abc import x, y, z + >>> b = (~x & ~y & ~z) | ( ~x & ~y & z) + >>> simplify_logic(b) + ~x & ~y + >>> simplify_logic(x | y, dontcare=y) + x + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Don%27t-care_term + + """ + + if form not in (None, 'cnf', 'dnf'): + raise ValueError("form can be cnf or dnf only") + expr = sympify(expr) + # check for quick exit if form is given: right form and all args are + # literal and do not involve Not + if form: + form_ok = False + if form == 'cnf': + form_ok = is_cnf(expr) + elif form == 'dnf': + form_ok = is_dnf(expr) + + if form_ok and all(is_literal(a) + for a in expr.args): + return expr + from sympy.core.relational import Relational + if deep: + variables = expr.atoms(Relational) + from sympy.simplify.simplify import simplify + s = tuple(map(simplify, variables)) + expr = expr.xreplace(dict(zip(variables, s))) + if not isinstance(expr, BooleanFunction): + return expr + # Replace Relationals with Dummys to possibly + # reduce the number of variables + repl = {} + undo = {} + from sympy.core.symbol import Dummy + variables = expr.atoms(Relational) + if dontcare is not None: + dontcare = sympify(dontcare) + variables.update(dontcare.atoms(Relational)) + while variables: + var = variables.pop() + if var.is_Relational: + d = Dummy() + undo[d] = var + repl[var] = d + nvar = var.negated + if nvar in variables: + repl[nvar] = Not(d) + variables.remove(nvar) + + expr = expr.xreplace(repl) + + if dontcare is not None: + dontcare = dontcare.xreplace(repl) + + # Get new variables after replacing + variables = _find_predicates(expr) + if not force and len(variables) > 8: + return expr.xreplace(undo) + if dontcare is not None: + # Add variables from dontcare + dcvariables = _find_predicates(dontcare) + variables.update(dcvariables) + # if too many restore to variables only + if not force and len(variables) > 8: + variables = _find_predicates(expr) + dontcare = None + # group into constants and variable values + c, v = sift(ordered(variables), lambda x: x in (True, False), binary=True) + variables = c + v + # standardize constants to be 1 or 0 in keeping with truthtable + c = [1 if i == True else 0 for i in c] + truthtable = _get_truthtable(v, expr, c) + if dontcare is not None: + dctruthtable = _get_truthtable(v, dontcare, c) + truthtable = [t for t in truthtable if t not in dctruthtable] + else: + dctruthtable = [] + big = len(truthtable) >= (2 ** (len(variables) - 1)) + if form == 'dnf' or form is None and big: + return _sop_form(variables, truthtable, dctruthtable).xreplace(undo) + return POSform(variables, truthtable, dctruthtable).xreplace(undo) + + +def _get_truthtable(variables, expr, const): + """ Return a list of all combinations leading to a True result for ``expr``. + """ + _variables = variables.copy() + def _get_tt(inputs): + if _variables: + v = _variables.pop() + tab = [[i[0].xreplace({v: false}), [0] + i[1]] for i in inputs if i[0] is not false] + tab.extend([[i[0].xreplace({v: true}), [1] + i[1]] for i in inputs if i[0] is not false]) + return _get_tt(tab) + return inputs + res = [const + k[1] for k in _get_tt([[expr, []]]) if k[0]] + if res == [[]]: + return [] + else: + return res + + +def _finger(eq): + """ + Assign a 5-item fingerprint to each symbol in the equation: + [ + # of times it appeared as a Symbol; + # of times it appeared as a Not(symbol); + # of times it appeared as a Symbol in an And or Or; + # of times it appeared as a Not(Symbol) in an And or Or; + a sorted tuple of tuples, (i, j, k), where i is the number of arguments + in an And or Or with which it appeared as a Symbol, and j is + the number of arguments that were Not(Symbol); k is the number + of times that (i, j) was seen. + ] + + Examples + ======== + + >>> from sympy.logic.boolalg import _finger as finger + >>> from sympy import And, Or, Not, Xor, to_cnf, symbols + >>> from sympy.abc import a, b, x, y + >>> eq = Or(And(Not(y), a), And(Not(y), b), And(x, y)) + >>> dict(finger(eq)) + {(0, 0, 1, 0, ((2, 0, 1),)): [x], + (0, 0, 1, 0, ((2, 1, 1),)): [a, b], + (0, 0, 1, 2, ((2, 0, 1),)): [y]} + >>> dict(finger(x & ~y)) + {(0, 1, 0, 0, ()): [y], (1, 0, 0, 0, ()): [x]} + + In the following, the (5, 2, 6) means that there were 6 Or + functions in which a symbol appeared as itself amongst 5 arguments in + which there were also 2 negated symbols, e.g. ``(a0 | a1 | a2 | ~a3 | ~a4)`` + is counted once for a0, a1 and a2. + + >>> dict(finger(to_cnf(Xor(*symbols('a:5'))))) + {(0, 0, 8, 8, ((5, 0, 1), (5, 2, 6), (5, 4, 1))): [a0, a1, a2, a3, a4]} + + The equation must not have more than one level of nesting: + + >>> dict(finger(And(Or(x, y), y))) + {(0, 0, 1, 0, ((2, 0, 1),)): [x], (1, 0, 1, 0, ((2, 0, 1),)): [y]} + >>> dict(finger(And(Or(x, And(a, x)), y))) + Traceback (most recent call last): + ... + NotImplementedError: unexpected level of nesting + + So y and x have unique fingerprints, but a and b do not. + """ + f = eq.free_symbols + d = dict(list(zip(f, [[0]*4 + [defaultdict(int)] for fi in f]))) + for a in eq.args: + if a.is_Symbol: + d[a][0] += 1 + elif a.is_Not: + d[a.args[0]][1] += 1 + else: + o = len(a.args), sum(isinstance(ai, Not) for ai in a.args) + for ai in a.args: + if ai.is_Symbol: + d[ai][2] += 1 + d[ai][-1][o] += 1 + elif ai.is_Not: + d[ai.args[0]][3] += 1 + else: + raise NotImplementedError('unexpected level of nesting') + inv = defaultdict(list) + for k, v in ordered(iter(d.items())): + v[-1] = tuple(sorted([i + (j,) for i, j in v[-1].items()])) + inv[tuple(v)].append(k) + return inv + + +def bool_map(bool1, bool2): + """ + Return the simplified version of *bool1*, and the mapping of variables + that makes the two expressions *bool1* and *bool2* represent the same + logical behaviour for some correspondence between the variables + of each. + If more than one mappings of this sort exist, one of them + is returned. + + For example, ``And(x, y)`` is logically equivalent to ``And(a, b)`` for + the mapping ``{x: a, y: b}`` or ``{x: b, y: a}``. + If no such mapping exists, return ``False``. + + Examples + ======== + + >>> from sympy import SOPform, bool_map, Or, And, Not, Xor + >>> from sympy.abc import w, x, y, z, a, b, c, d + >>> function1 = SOPform([x, z, y],[[1, 0, 1], [0, 0, 1]]) + >>> function2 = SOPform([a, b, c],[[1, 0, 1], [1, 0, 0]]) + >>> bool_map(function1, function2) + (y & ~z, {y: a, z: b}) + + The results are not necessarily unique, but they are canonical. Here, + ``(w, z)`` could be ``(a, d)`` or ``(d, a)``: + + >>> eq = Or(And(Not(y), w), And(Not(y), z), And(x, y)) + >>> eq2 = Or(And(Not(c), a), And(Not(c), d), And(b, c)) + >>> bool_map(eq, eq2) + ((x & y) | (w & ~y) | (z & ~y), {w: a, x: b, y: c, z: d}) + >>> eq = And(Xor(a, b), c, And(c,d)) + >>> bool_map(eq, eq.subs(c, x)) + (c & d & (a | b) & (~a | ~b), {a: a, b: b, c: d, d: x}) + + """ + + def match(function1, function2): + """Return the mapping that equates variables between two + simplified boolean expressions if possible. + + By "simplified" we mean that a function has been denested + and is either an And (or an Or) whose arguments are either + symbols (x), negated symbols (Not(x)), or Or (or an And) whose + arguments are only symbols or negated symbols. For example, + ``And(x, Not(y), Or(w, Not(z)))``. + + Basic.match is not robust enough (see issue 4835) so this is + a workaround that is valid for simplified boolean expressions + """ + + # do some quick checks + if function1.__class__ != function2.__class__: + return None # maybe simplification makes them the same? + if len(function1.args) != len(function2.args): + return None # maybe simplification makes them the same? + if function1.is_Symbol: + return {function1: function2} + + # get the fingerprint dictionaries + f1 = _finger(function1) + f2 = _finger(function2) + + # more quick checks + if len(f1) != len(f2): + return False + + # assemble the match dictionary if possible + matchdict = {} + for k in f1.keys(): + if k not in f2: + return False + if len(f1[k]) != len(f2[k]): + return False + for i, x in enumerate(f1[k]): + matchdict[x] = f2[k][i] + return matchdict + + a = simplify_logic(bool1) + b = simplify_logic(bool2) + m = match(a, b) + if m: + return a, m + return m + + +def _apply_patternbased_simplification(rv, patterns, measure, + dominatingvalue, + replacementvalue=None, + threeterm_patterns=None): + """ + Replace patterns of Relational + + Parameters + ========== + + rv : Expr + Boolean expression + + patterns : tuple + Tuple of tuples, with (pattern to simplify, simplified pattern) with + two terms. + + measure : function + Simplification measure. + + dominatingvalue : Boolean or ``None`` + The dominating value for the function of consideration. + For example, for :py:class:`~.And` ``S.false`` is dominating. + As soon as one expression is ``S.false`` in :py:class:`~.And`, + the whole expression is ``S.false``. + + replacementvalue : Boolean or ``None``, optional + The resulting value for the whole expression if one argument + evaluates to ``dominatingvalue``. + For example, for :py:class:`~.Nand` ``S.false`` is dominating, but + in this case the resulting value is ``S.true``. Default is ``None``. + If ``replacementvalue`` is ``None`` and ``dominatingvalue`` is not + ``None``, ``replacementvalue = dominatingvalue``. + + threeterm_patterns : tuple, optional + Tuple of tuples, with (pattern to simplify, simplified pattern) with + three terms. + + """ + from sympy.core.relational import Relational, _canonical + + if replacementvalue is None and dominatingvalue is not None: + replacementvalue = dominatingvalue + # Use replacement patterns for Relationals + Rel, nonRel = sift(rv.args, lambda i: isinstance(i, Relational), + binary=True) + if len(Rel) <= 1: + return rv + Rel, nonRealRel = sift(Rel, lambda i: not any(s.is_real is False + for s in i.free_symbols), + binary=True) + Rel = [i.canonical for i in Rel] + + if threeterm_patterns and len(Rel) >= 3: + Rel = _apply_patternbased_threeterm_simplification(Rel, + threeterm_patterns, rv.func, dominatingvalue, + replacementvalue, measure) + + Rel = _apply_patternbased_twoterm_simplification(Rel, patterns, + rv.func, dominatingvalue, replacementvalue, measure) + + rv = rv.func(*([_canonical(i) for i in ordered(Rel)] + + nonRel + nonRealRel)) + return rv + + +def _apply_patternbased_twoterm_simplification(Rel, patterns, func, + dominatingvalue, + replacementvalue, + measure): + """ Apply pattern-based two-term simplification.""" + from sympy.functions.elementary.miscellaneous import Min, Max + from sympy.core.relational import Ge, Gt, _Inequality + changed = True + while changed and len(Rel) >= 2: + changed = False + # Use only < or <= + Rel = [r.reversed if isinstance(r, (Ge, Gt)) else r for r in Rel] + # Sort based on ordered + Rel = list(ordered(Rel)) + # Eq and Ne must be tested reversed as well + rtmp = [(r, ) if isinstance(r, _Inequality) else (r, r.reversed) for r in Rel] + # Create a list of possible replacements + results = [] + # Try all combinations of possibly reversed relational + for ((i, pi), (j, pj)) in combinations(enumerate(rtmp), 2): + for pattern, simp in patterns: + res = [] + for p1, p2 in product(pi, pj): + # use SymPy matching + oldexpr = Tuple(p1, p2) + tmpres = oldexpr.match(pattern) + if tmpres: + res.append((tmpres, oldexpr)) + + if res: + for tmpres, oldexpr in res: + # we have a matching, compute replacement + np = simp.xreplace(tmpres) + if np == dominatingvalue: + # if dominatingvalue, the whole expression + # will be replacementvalue + return [replacementvalue] + # add replacement + if not isinstance(np, ITE) and not np.has(Min, Max): + # We only want to use ITE and Min/Max replacements if + # they simplify to a relational + costsaving = measure(func(*oldexpr.args)) - measure(np) + if costsaving > 0: + results.append((costsaving, ([i, j], np))) + if results: + # Sort results based on complexity + results = sorted(results, + key=lambda pair: pair[0], reverse=True) + # Replace the one providing most simplification + replacement = results[0][1] + idx, newrel = replacement + idx.sort() + # Remove the old relationals + for index in reversed(idx): + del Rel[index] + if dominatingvalue is None or newrel != Not(dominatingvalue): + # Insert the new one (no need to insert a value that will + # not affect the result) + if newrel.func == func: + for a in newrel.args: + Rel.append(a) + else: + Rel.append(newrel) + # We did change something so try again + changed = True + return Rel + + +def _apply_patternbased_threeterm_simplification(Rel, patterns, func, + dominatingvalue, + replacementvalue, + measure): + """ Apply pattern-based three-term simplification.""" + from sympy.functions.elementary.miscellaneous import Min, Max + from sympy.core.relational import Le, Lt, _Inequality + changed = True + while changed and len(Rel) >= 3: + changed = False + # Use only > or >= + Rel = [r.reversed if isinstance(r, (Le, Lt)) else r for r in Rel] + # Sort based on ordered + Rel = list(ordered(Rel)) + # Create a list of possible replacements + results = [] + # Eq and Ne must be tested reversed as well + rtmp = [(r, ) if isinstance(r, _Inequality) else (r, r.reversed) for r in Rel] + # Try all combinations of possibly reversed relational + for ((i, pi), (j, pj), (k, pk)) in permutations(enumerate(rtmp), 3): + for pattern, simp in patterns: + res = [] + for p1, p2, p3 in product(pi, pj, pk): + # use SymPy matching + oldexpr = Tuple(p1, p2, p3) + tmpres = oldexpr.match(pattern) + if tmpres: + res.append((tmpres, oldexpr)) + + if res: + for tmpres, oldexpr in res: + # we have a matching, compute replacement + np = simp.xreplace(tmpres) + if np == dominatingvalue: + # if dominatingvalue, the whole expression + # will be replacementvalue + return [replacementvalue] + # add replacement + if not isinstance(np, ITE) and not np.has(Min, Max): + # We only want to use ITE and Min/Max replacements if + # they simplify to a relational + costsaving = measure(func(*oldexpr.args)) - measure(np) + if costsaving > 0: + results.append((costsaving, ([i, j, k], np))) + if results: + # Sort results based on complexity + results = sorted(results, + key=lambda pair: pair[0], reverse=True) + # Replace the one providing most simplification + replacement = results[0][1] + idx, newrel = replacement + idx.sort() + # Remove the old relationals + for index in reversed(idx): + del Rel[index] + if dominatingvalue is None or newrel != Not(dominatingvalue): + # Insert the new one (no need to insert a value that will + # not affect the result) + if newrel.func == func: + for a in newrel.args: + Rel.append(a) + else: + Rel.append(newrel) + # We did change something so try again + changed = True + return Rel + + +@cacheit +def _simplify_patterns_and(): + """ Two-term patterns for And.""" + + from sympy.core import Wild + from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt + from sympy.functions.elementary.complexes import Abs + from sympy.functions.elementary.miscellaneous import Min, Max + a = Wild('a') + b = Wild('b') + c = Wild('c') + # Relationals patterns should be in alphabetical order + # (pattern1, pattern2, simplified) + # Do not use Ge, Gt + _matchers_and = ((Tuple(Eq(a, b), Lt(a, b)), false), + #(Tuple(Eq(a, b), Lt(b, a)), S.false), + #(Tuple(Le(b, a), Lt(a, b)), S.false), + #(Tuple(Lt(b, a), Le(a, b)), S.false), + (Tuple(Lt(b, a), Lt(a, b)), false), + (Tuple(Eq(a, b), Le(b, a)), Eq(a, b)), + #(Tuple(Eq(a, b), Le(a, b)), Eq(a, b)), + #(Tuple(Le(b, a), Lt(b, a)), Gt(a, b)), + (Tuple(Le(b, a), Le(a, b)), Eq(a, b)), + #(Tuple(Le(b, a), Ne(a, b)), Gt(a, b)), + #(Tuple(Lt(b, a), Ne(a, b)), Gt(a, b)), + (Tuple(Le(a, b), Lt(a, b)), Lt(a, b)), + (Tuple(Le(a, b), Ne(a, b)), Lt(a, b)), + (Tuple(Lt(a, b), Ne(a, b)), Lt(a, b)), + # Sign + (Tuple(Eq(a, b), Eq(a, -b)), And(Eq(a, S.Zero), Eq(b, S.Zero))), + # Min/Max/ITE + (Tuple(Le(b, a), Le(c, a)), Ge(a, Max(b, c))), + (Tuple(Le(b, a), Lt(c, a)), ITE(b > c, Ge(a, b), Gt(a, c))), + (Tuple(Lt(b, a), Lt(c, a)), Gt(a, Max(b, c))), + (Tuple(Le(a, b), Le(a, c)), Le(a, Min(b, c))), + (Tuple(Le(a, b), Lt(a, c)), ITE(b < c, Le(a, b), Lt(a, c))), + (Tuple(Lt(a, b), Lt(a, c)), Lt(a, Min(b, c))), + (Tuple(Le(a, b), Le(c, a)), ITE(Eq(b, c), Eq(a, b), ITE(b < c, false, And(Le(a, b), Ge(a, c))))), + (Tuple(Le(c, a), Le(a, b)), ITE(Eq(b, c), Eq(a, b), ITE(b < c, false, And(Le(a, b), Ge(a, c))))), + (Tuple(Lt(a, b), Lt(c, a)), ITE(b < c, false, And(Lt(a, b), Gt(a, c)))), + (Tuple(Lt(c, a), Lt(a, b)), ITE(b < c, false, And(Lt(a, b), Gt(a, c)))), + (Tuple(Le(a, b), Lt(c, a)), ITE(b <= c, false, And(Le(a, b), Gt(a, c)))), + (Tuple(Le(c, a), Lt(a, b)), ITE(b <= c, false, And(Lt(a, b), Ge(a, c)))), + (Tuple(Eq(a, b), Eq(a, c)), ITE(Eq(b, c), Eq(a, b), false)), + (Tuple(Lt(a, b), Lt(-b, a)), ITE(b > 0, Lt(Abs(a), b), false)), + (Tuple(Le(a, b), Le(-b, a)), ITE(b >= 0, Le(Abs(a), b), false)), + ) + return _matchers_and + + +@cacheit +def _simplify_patterns_and3(): + """ Three-term patterns for And.""" + + from sympy.core import Wild + from sympy.core.relational import Eq, Ge, Gt + + a = Wild('a') + b = Wild('b') + c = Wild('c') + # Relationals patterns should be in alphabetical order + # (pattern1, pattern2, pattern3, simplified) + # Do not use Le, Lt + _matchers_and = ((Tuple(Ge(a, b), Ge(b, c), Gt(c, a)), false), + (Tuple(Ge(a, b), Gt(b, c), Gt(c, a)), false), + (Tuple(Gt(a, b), Gt(b, c), Gt(c, a)), false), + # (Tuple(Ge(c, a), Gt(a, b), Gt(b, c)), S.false), + # Lower bound relations + # Commented out combinations that does not simplify + (Tuple(Ge(a, b), Ge(a, c), Ge(b, c)), And(Ge(a, b), Ge(b, c))), + (Tuple(Ge(a, b), Ge(a, c), Gt(b, c)), And(Ge(a, b), Gt(b, c))), + # (Tuple(Ge(a, b), Gt(a, c), Ge(b, c)), And(Ge(a, b), Ge(b, c))), + (Tuple(Ge(a, b), Gt(a, c), Gt(b, c)), And(Ge(a, b), Gt(b, c))), + # (Tuple(Gt(a, b), Ge(a, c), Ge(b, c)), And(Gt(a, b), Ge(b, c))), + (Tuple(Ge(a, c), Gt(a, b), Gt(b, c)), And(Gt(a, b), Gt(b, c))), + (Tuple(Ge(b, c), Gt(a, b), Gt(a, c)), And(Gt(a, b), Ge(b, c))), + (Tuple(Gt(a, b), Gt(a, c), Gt(b, c)), And(Gt(a, b), Gt(b, c))), + # Upper bound relations + # Commented out combinations that does not simplify + (Tuple(Ge(b, a), Ge(c, a), Ge(b, c)), And(Ge(c, a), Ge(b, c))), + (Tuple(Ge(b, a), Ge(c, a), Gt(b, c)), And(Ge(c, a), Gt(b, c))), + # (Tuple(Ge(b, a), Gt(c, a), Ge(b, c)), And(Gt(c, a), Ge(b, c))), + (Tuple(Ge(b, a), Gt(c, a), Gt(b, c)), And(Gt(c, a), Gt(b, c))), + # (Tuple(Gt(b, a), Ge(c, a), Ge(b, c)), And(Ge(c, a), Ge(b, c))), + (Tuple(Ge(c, a), Gt(b, a), Gt(b, c)), And(Ge(c, a), Gt(b, c))), + (Tuple(Ge(b, c), Gt(b, a), Gt(c, a)), And(Gt(c, a), Ge(b, c))), + (Tuple(Gt(b, a), Gt(c, a), Gt(b, c)), And(Gt(c, a), Gt(b, c))), + # Circular relation + (Tuple(Ge(a, b), Ge(b, c), Ge(c, a)), And(Eq(a, b), Eq(b, c))), + ) + return _matchers_and + + +@cacheit +def _simplify_patterns_or(): + """ Two-term patterns for Or.""" + + from sympy.core import Wild + from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt + from sympy.functions.elementary.complexes import Abs + from sympy.functions.elementary.miscellaneous import Min, Max + a = Wild('a') + b = Wild('b') + c = Wild('c') + # Relationals patterns should be in alphabetical order + # (pattern1, pattern2, simplified) + # Do not use Ge, Gt + _matchers_or = ((Tuple(Le(b, a), Le(a, b)), true), + #(Tuple(Le(b, a), Lt(a, b)), true), + (Tuple(Le(b, a), Ne(a, b)), true), + #(Tuple(Le(a, b), Lt(b, a)), true), + #(Tuple(Le(a, b), Ne(a, b)), true), + #(Tuple(Eq(a, b), Le(b, a)), Ge(a, b)), + #(Tuple(Eq(a, b), Lt(b, a)), Ge(a, b)), + (Tuple(Eq(a, b), Le(a, b)), Le(a, b)), + (Tuple(Eq(a, b), Lt(a, b)), Le(a, b)), + #(Tuple(Le(b, a), Lt(b, a)), Ge(a, b)), + (Tuple(Lt(b, a), Lt(a, b)), Ne(a, b)), + (Tuple(Lt(b, a), Ne(a, b)), Ne(a, b)), + (Tuple(Le(a, b), Lt(a, b)), Le(a, b)), + #(Tuple(Lt(a, b), Ne(a, b)), Ne(a, b)), + (Tuple(Eq(a, b), Ne(a, c)), ITE(Eq(b, c), true, Ne(a, c))), + (Tuple(Ne(a, b), Ne(a, c)), ITE(Eq(b, c), Ne(a, b), true)), + # Min/Max/ITE + (Tuple(Le(b, a), Le(c, a)), Ge(a, Min(b, c))), + #(Tuple(Ge(b, a), Ge(c, a)), Ge(Min(b, c), a)), + (Tuple(Le(b, a), Lt(c, a)), ITE(b > c, Lt(c, a), Le(b, a))), + (Tuple(Lt(b, a), Lt(c, a)), Gt(a, Min(b, c))), + #(Tuple(Gt(b, a), Gt(c, a)), Gt(Min(b, c), a)), + (Tuple(Le(a, b), Le(a, c)), Le(a, Max(b, c))), + #(Tuple(Le(b, a), Le(c, a)), Le(Max(b, c), a)), + (Tuple(Le(a, b), Lt(a, c)), ITE(b >= c, Le(a, b), Lt(a, c))), + (Tuple(Lt(a, b), Lt(a, c)), Lt(a, Max(b, c))), + #(Tuple(Lt(b, a), Lt(c, a)), Lt(Max(b, c), a)), + (Tuple(Le(a, b), Le(c, a)), ITE(b >= c, true, Or(Le(a, b), Ge(a, c)))), + (Tuple(Le(c, a), Le(a, b)), ITE(b >= c, true, Or(Le(a, b), Ge(a, c)))), + (Tuple(Lt(a, b), Lt(c, a)), ITE(b > c, true, Or(Lt(a, b), Gt(a, c)))), + (Tuple(Lt(c, a), Lt(a, b)), ITE(b > c, true, Or(Lt(a, b), Gt(a, c)))), + (Tuple(Le(a, b), Lt(c, a)), ITE(b >= c, true, Or(Le(a, b), Gt(a, c)))), + (Tuple(Le(c, a), Lt(a, b)), ITE(b >= c, true, Or(Lt(a, b), Ge(a, c)))), + (Tuple(Lt(b, a), Lt(a, -b)), ITE(b >= 0, Gt(Abs(a), b), true)), + (Tuple(Le(b, a), Le(a, -b)), ITE(b > 0, Ge(Abs(a), b), true)), + ) + return _matchers_or + + +@cacheit +def _simplify_patterns_xor(): + """ Two-term patterns for Xor.""" + + from sympy.functions.elementary.miscellaneous import Min, Max + from sympy.core import Wild + from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt + a = Wild('a') + b = Wild('b') + c = Wild('c') + # Relationals patterns should be in alphabetical order + # (pattern1, pattern2, simplified) + # Do not use Ge, Gt + _matchers_xor = (#(Tuple(Le(b, a), Lt(a, b)), true), + #(Tuple(Lt(b, a), Le(a, b)), true), + #(Tuple(Eq(a, b), Le(b, a)), Gt(a, b)), + #(Tuple(Eq(a, b), Lt(b, a)), Ge(a, b)), + (Tuple(Eq(a, b), Le(a, b)), Lt(a, b)), + (Tuple(Eq(a, b), Lt(a, b)), Le(a, b)), + (Tuple(Le(a, b), Lt(a, b)), Eq(a, b)), + (Tuple(Le(a, b), Le(b, a)), Ne(a, b)), + (Tuple(Le(b, a), Ne(a, b)), Le(a, b)), + # (Tuple(Lt(b, a), Lt(a, b)), Ne(a, b)), + (Tuple(Lt(b, a), Ne(a, b)), Lt(a, b)), + # (Tuple(Le(a, b), Lt(a, b)), Eq(a, b)), + # (Tuple(Le(a, b), Ne(a, b)), Ge(a, b)), + # (Tuple(Lt(a, b), Ne(a, b)), Gt(a, b)), + # Min/Max/ITE + (Tuple(Le(b, a), Le(c, a)), + And(Ge(a, Min(b, c)), Lt(a, Max(b, c)))), + (Tuple(Le(b, a), Lt(c, a)), + ITE(b > c, And(Gt(a, c), Lt(a, b)), + And(Ge(a, b), Le(a, c)))), + (Tuple(Lt(b, a), Lt(c, a)), + And(Gt(a, Min(b, c)), Le(a, Max(b, c)))), + (Tuple(Le(a, b), Le(a, c)), + And(Le(a, Max(b, c)), Gt(a, Min(b, c)))), + (Tuple(Le(a, b), Lt(a, c)), + ITE(b < c, And(Lt(a, c), Gt(a, b)), + And(Le(a, b), Ge(a, c)))), + (Tuple(Lt(a, b), Lt(a, c)), + And(Lt(a, Max(b, c)), Ge(a, Min(b, c)))), + ) + return _matchers_xor + + +def simplify_univariate(expr): + """return a simplified version of univariate boolean expression, else ``expr``""" + from sympy.functions.elementary.piecewise import Piecewise + from sympy.core.relational import Eq, Ne + if not isinstance(expr, BooleanFunction): + return expr + if expr.atoms(Eq, Ne): + return expr + c = expr + free = c.free_symbols + if len(free) != 1: + return c + x = free.pop() + ok, i = Piecewise((0, c), evaluate=False + )._intervals(x, err_on_Eq=True) + if not ok: + return c + if not i: + return false + args = [] + for a, b, _, _ in i: + if a is S.NegativeInfinity: + if b is S.Infinity: + c = true + else: + if c.subs(x, b) == True: + c = (x <= b) + else: + c = (x < b) + else: + incl_a = (c.subs(x, a) == True) + incl_b = (c.subs(x, b) == True) + if incl_a and incl_b: + if b.is_infinite: + c = (x >= a) + else: + c = And(a <= x, x <= b) + elif incl_a: + c = And(a <= x, x < b) + elif incl_b: + if b.is_infinite: + c = (x > a) + else: + c = And(a < x, x <= b) + else: + c = And(a < x, x < b) + args.append(c) + return Or(*args) + + +# Classes corresponding to logic gates +# Used in gateinputcount method +BooleanGates = (And, Or, Xor, Nand, Nor, Not, Xnor, ITE) + +def gateinputcount(expr): + """ + Return the total number of inputs for the logic gates realizing the + Boolean expression. + + Returns + ======= + + int + Number of gate inputs + + Note + ==== + + Not all Boolean functions count as gate here, only those that are + considered to be standard gates. These are: :py:class:`~.And`, + :py:class:`~.Or`, :py:class:`~.Xor`, :py:class:`~.Not`, and + :py:class:`~.ITE` (multiplexer). :py:class:`~.Nand`, :py:class:`~.Nor`, + and :py:class:`~.Xnor` will be evaluated to ``Not(And())`` etc. + + Examples + ======== + + >>> from sympy.logic import And, Or, Nand, Not, gateinputcount + >>> from sympy.abc import x, y, z + >>> expr = And(x, y) + >>> gateinputcount(expr) + 2 + >>> gateinputcount(Or(expr, z)) + 4 + + Note that ``Nand`` is automatically evaluated to ``Not(And())`` so + + >>> gateinputcount(Nand(x, y, z)) + 4 + >>> gateinputcount(Not(And(x, y, z))) + 4 + + Although this can be avoided by using ``evaluate=False`` + + >>> gateinputcount(Nand(x, y, z, evaluate=False)) + 3 + + Also note that a comparison will count as a Boolean variable: + + >>> gateinputcount(And(x > z, y >= 2)) + 2 + + As will a symbol: + >>> gateinputcount(x) + 0 + + """ + if not isinstance(expr, Boolean): + raise TypeError("Expression must be Boolean") + if isinstance(expr, BooleanGates): + return len(expr.args) + sum(gateinputcount(x) for x in expr.args) + return 0 diff --git a/.venv/lib/python3.13/site-packages/sympy/logic/inference.py b/.venv/lib/python3.13/site-packages/sympy/logic/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..c3798231c09ae351ea7e7026d622b834fea3e3fa --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/logic/inference.py @@ -0,0 +1,340 @@ +"""Inference in propositional logic""" + +from sympy.logic.boolalg import And, Not, conjuncts, to_cnf, BooleanFunction +from sympy.core.sorting import ordered +from sympy.core.sympify import sympify +from sympy.external.importtools import import_module + + +def literal_symbol(literal): + """ + The symbol in this literal (without the negation). + + Examples + ======== + + >>> from sympy.abc import A + >>> from sympy.logic.inference import literal_symbol + >>> literal_symbol(A) + A + >>> literal_symbol(~A) + A + + """ + + if literal is True or literal is False: + return literal + elif literal.is_Symbol: + return literal + elif literal.is_Not: + return literal_symbol(literal.args[0]) + else: + raise ValueError("Argument must be a boolean literal.") + + +def satisfiable(expr, algorithm=None, all_models=False, minimal=False, use_lra_theory=False): + """ + Check satisfiability of a propositional sentence. + Returns a model when it succeeds. + Returns {true: true} for trivially true expressions. + + On setting all_models to True, if given expr is satisfiable then + returns a generator of models. However, if expr is unsatisfiable + then returns a generator containing the single element False. + + Examples + ======== + + >>> from sympy.abc import A, B + >>> from sympy.logic.inference import satisfiable + >>> satisfiable(A & ~B) + {A: True, B: False} + >>> satisfiable(A & ~A) + False + >>> satisfiable(True) + {True: True} + >>> next(satisfiable(A & ~A, all_models=True)) + False + >>> models = satisfiable((A >> B) & B, all_models=True) + >>> next(models) + {A: False, B: True} + >>> next(models) + {A: True, B: True} + >>> def use_models(models): + ... for model in models: + ... if model: + ... # Do something with the model. + ... print(model) + ... else: + ... # Given expr is unsatisfiable. + ... print("UNSAT") + >>> use_models(satisfiable(A >> ~A, all_models=True)) + {A: False} + >>> use_models(satisfiable(A ^ A, all_models=True)) + UNSAT + + """ + if use_lra_theory: + if algorithm is not None and algorithm != "dpll2": + raise ValueError(f"Currently only dpll2 can handle using lra theory. {algorithm} is not handled.") + algorithm = "dpll2" + + if algorithm is None or algorithm == "pycosat": + pycosat = import_module('pycosat') + if pycosat is not None: + algorithm = "pycosat" + else: + if algorithm == "pycosat": + raise ImportError("pycosat module is not present") + # Silently fall back to dpll2 if pycosat + # is not installed + algorithm = "dpll2" + + if algorithm=="minisat22": + pysat = import_module('pysat') + if pysat is None: + algorithm = "dpll2" + + if algorithm=="z3": + z3 = import_module('z3') + if z3 is None: + algorithm = "dpll2" + + if algorithm == "dpll": + from sympy.logic.algorithms.dpll import dpll_satisfiable + return dpll_satisfiable(expr) + elif algorithm == "dpll2": + from sympy.logic.algorithms.dpll2 import dpll_satisfiable + return dpll_satisfiable(expr, all_models, use_lra_theory=use_lra_theory) + elif algorithm == "pycosat": + from sympy.logic.algorithms.pycosat_wrapper import pycosat_satisfiable + return pycosat_satisfiable(expr, all_models) + elif algorithm == "minisat22": + from sympy.logic.algorithms.minisat22_wrapper import minisat22_satisfiable + return minisat22_satisfiable(expr, all_models, minimal) + elif algorithm == "z3": + from sympy.logic.algorithms.z3_wrapper import z3_satisfiable + return z3_satisfiable(expr, all_models) + + raise NotImplementedError + + +def valid(expr): + """ + Check validity of a propositional sentence. + A valid propositional sentence is True under every assignment. + + Examples + ======== + + >>> from sympy.abc import A, B + >>> from sympy.logic.inference import valid + >>> valid(A | ~A) + True + >>> valid(A | B) + False + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Validity + + """ + return not satisfiable(Not(expr)) + + +def pl_true(expr, model=None, deep=False): + """ + Returns whether the given assignment is a model or not. + + If the assignment does not specify the value for every proposition, + this may return None to indicate 'not obvious'. + + Parameters + ========== + + model : dict, optional, default: {} + Mapping of symbols to boolean values to indicate assignment. + deep: boolean, optional, default: False + Gives the value of the expression under partial assignments + correctly. May still return None to indicate 'not obvious'. + + + Examples + ======== + + >>> from sympy.abc import A, B + >>> from sympy.logic.inference import pl_true + >>> pl_true( A & B, {A: True, B: True}) + True + >>> pl_true(A & B, {A: False}) + False + >>> pl_true(A & B, {A: True}) + >>> pl_true(A & B, {A: True}, deep=True) + >>> pl_true(A >> (B >> A)) + >>> pl_true(A >> (B >> A), deep=True) + True + >>> pl_true(A & ~A) + >>> pl_true(A & ~A, deep=True) + False + >>> pl_true(A & B & (~A | ~B), {A: True}) + >>> pl_true(A & B & (~A | ~B), {A: True}, deep=True) + False + + """ + + from sympy.core.symbol import Symbol + + boolean = (True, False) + + def _validate(expr): + if isinstance(expr, Symbol) or expr in boolean: + return True + if not isinstance(expr, BooleanFunction): + return False + return all(_validate(arg) for arg in expr.args) + + if expr in boolean: + return expr + expr = sympify(expr) + if not _validate(expr): + raise ValueError("%s is not a valid boolean expression" % expr) + if not model: + model = {} + model = {k: v for k, v in model.items() if v in boolean} + result = expr.subs(model) + if result in boolean: + return bool(result) + if deep: + model = dict.fromkeys(result.atoms(), True) + if pl_true(result, model): + if valid(result): + return True + else: + if not satisfiable(result): + return False + return None + + +def entails(expr, formula_set=None): + """ + Check whether the given expr_set entail an expr. + If formula_set is empty then it returns the validity of expr. + + Examples + ======== + + >>> from sympy.abc import A, B, C + >>> from sympy.logic.inference import entails + >>> entails(A, [A >> B, B >> C]) + False + >>> entails(C, [A >> B, B >> C, A]) + True + >>> entails(A >> B) + False + >>> entails(A >> (B >> A)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Logical_consequence + + """ + if formula_set: + formula_set = list(formula_set) + else: + formula_set = [] + formula_set.append(Not(expr)) + return not satisfiable(And(*formula_set)) + + +class KB: + """Base class for all knowledge bases""" + def __init__(self, sentence=None): + self.clauses_ = set() + if sentence: + self.tell(sentence) + + def tell(self, sentence): + raise NotImplementedError + + def ask(self, query): + raise NotImplementedError + + def retract(self, sentence): + raise NotImplementedError + + @property + def clauses(self): + return list(ordered(self.clauses_)) + + +class PropKB(KB): + """A KB for Propositional Logic. Inefficient, with no indexing.""" + + def tell(self, sentence): + """Add the sentence's clauses to the KB + + Examples + ======== + + >>> from sympy.logic.inference import PropKB + >>> from sympy.abc import x, y + >>> l = PropKB() + >>> l.clauses + [] + + >>> l.tell(x | y) + >>> l.clauses + [x | y] + + >>> l.tell(y) + >>> l.clauses + [y, x | y] + + """ + for c in conjuncts(to_cnf(sentence)): + self.clauses_.add(c) + + def ask(self, query): + """Checks if the query is true given the set of clauses. + + Examples + ======== + + >>> from sympy.logic.inference import PropKB + >>> from sympy.abc import x, y + >>> l = PropKB() + >>> l.tell(x & ~y) + >>> l.ask(x) + True + >>> l.ask(y) + False + + """ + return entails(query, self.clauses_) + + def retract(self, sentence): + """Remove the sentence's clauses from the KB + + Examples + ======== + + >>> from sympy.logic.inference import PropKB + >>> from sympy.abc import x, y + >>> l = PropKB() + >>> l.clauses + [] + + >>> l.tell(x | y) + >>> l.clauses + [x | y] + + >>> l.retract(x | y) + >>> l.clauses + [] + + """ + for c in conjuncts(to_cnf(sentence)): + self.clauses_.discard(c) diff --git a/.venv/lib/python3.13/site-packages/sympy/multipledispatch/__init__.py b/.venv/lib/python3.13/site-packages/sympy/multipledispatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5447651645e3e2e92df3002822e87a773ade0df8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/multipledispatch/__init__.py @@ -0,0 +1,11 @@ +from .core import dispatch +from .dispatcher import (Dispatcher, halt_ordering, restart_ordering, + MDNotImplementedError) + +__version__ = '0.4.9' + +__all__ = [ + 'dispatch', + + 'Dispatcher', 'halt_ordering', 'restart_ordering', 'MDNotImplementedError', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/multipledispatch/conflict.py b/.venv/lib/python3.13/site-packages/sympy/multipledispatch/conflict.py new file mode 100644 index 0000000000000000000000000000000000000000..98c6742c9c03860233ef0004b241ea3944ac6d4d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/multipledispatch/conflict.py @@ -0,0 +1,68 @@ +from .utils import _toposort, groupby + +class AmbiguityWarning(Warning): + pass + + +def supercedes(a, b): + """ A is consistent and strictly more specific than B """ + return len(a) == len(b) and all(map(issubclass, a, b)) + + +def consistent(a, b): + """ It is possible for an argument list to satisfy both A and B """ + return (len(a) == len(b) and + all(issubclass(aa, bb) or issubclass(bb, aa) + for aa, bb in zip(a, b))) + + +def ambiguous(a, b): + """ A is consistent with B but neither is strictly more specific """ + return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a)) + + +def ambiguities(signatures): + """ All signature pairs such that A is ambiguous with B """ + signatures = list(map(tuple, signatures)) + return {(a, b) for a in signatures for b in signatures + if hash(a) < hash(b) + and ambiguous(a, b) + and not any(supercedes(c, a) and supercedes(c, b) + for c in signatures)} + + +def super_signature(signatures): + """ A signature that would break ambiguities """ + n = len(signatures[0]) + assert all(len(s) == n for s in signatures) + + return [max([type.mro(sig[i]) for sig in signatures], key=len)[0] + for i in range(n)] + + +def edge(a, b, tie_breaker=hash): + """ A should be checked before B + + Tie broken by tie_breaker, defaults to ``hash`` + """ + if supercedes(a, b): + if supercedes(b, a): + return tie_breaker(a) > tie_breaker(b) + else: + return True + return False + + +def ordering(signatures): + """ A sane ordering of signatures to check, first to last + + Topoological sort of edges as given by ``edge`` and ``supercedes`` + """ + signatures = list(map(tuple, signatures)) + edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] + edges = groupby(lambda x: x[0], edges) + for s in signatures: + if s not in edges: + edges[s] = [] + edges = {k: [b for a, b in v] for k, v in edges.items()} + return _toposort(edges) diff --git a/.venv/lib/python3.13/site-packages/sympy/multipledispatch/core.py b/.venv/lib/python3.13/site-packages/sympy/multipledispatch/core.py new file mode 100644 index 0000000000000000000000000000000000000000..2856ff728c4eb97c5a59fffabddb4bf3c8b4baf2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/multipledispatch/core.py @@ -0,0 +1,83 @@ +from __future__ import annotations +from typing import Any + +import inspect + +from .dispatcher import Dispatcher, MethodDispatcher, ambiguity_warn + +# XXX: This parameter to dispatch isn't documented and isn't used anywhere in +# sympy. Maybe it should just be removed. +global_namespace: dict[str, Any] = {} + + +def dispatch(*types, namespace=global_namespace, on_ambiguity=ambiguity_warn): + """ Dispatch function on the types of the inputs + + Supports dispatch on all non-keyword arguments. + + Collects implementations based on the function name. Ignores namespaces. + + If ambiguous type signatures occur a warning is raised when the function is + defined suggesting the additional method to break the ambiguity. + + Examples + -------- + + >>> from sympy.multipledispatch import dispatch + >>> @dispatch(int) + ... def f(x): + ... return x + 1 + + >>> @dispatch(float) + ... def f(x): # noqa: F811 + ... return x - 1 + + >>> f(3) + 4 + >>> f(3.0) + 2.0 + + Specify an isolated namespace with the namespace keyword argument + + >>> my_namespace = dict() + >>> @dispatch(int, namespace=my_namespace) + ... def foo(x): + ... return x + 1 + + Dispatch on instance methods within classes + + >>> class MyClass(object): + ... @dispatch(list) + ... def __init__(self, data): + ... self.data = data + ... @dispatch(int) + ... def __init__(self, datum): # noqa: F811 + ... self.data = [datum] + """ + types = tuple(types) + + def _(func): + name = func.__name__ + + if ismethod(func): + dispatcher = inspect.currentframe().f_back.f_locals.get( + name, + MethodDispatcher(name)) + else: + if name not in namespace: + namespace[name] = Dispatcher(name) + dispatcher = namespace[name] + + dispatcher.add(types, func, on_ambiguity=on_ambiguity) + return dispatcher + return _ + + +def ismethod(func): + """ Is func a method? + + Note that this has to work as the method is defined but before the class is + defined. At this stage methods look like functions. + """ + signature = inspect.signature(func) + return signature.parameters.get('self', None) is not None diff --git a/.venv/lib/python3.13/site-packages/sympy/multipledispatch/dispatcher.py b/.venv/lib/python3.13/site-packages/sympy/multipledispatch/dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..89471d678e1c330138a91ec6a41a324d29a037d7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/multipledispatch/dispatcher.py @@ -0,0 +1,413 @@ +from __future__ import annotations + +from warnings import warn +import inspect +from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning +from .utils import expand_tuples +import itertools as itl + + +class MDNotImplementedError(NotImplementedError): + """ A NotImplementedError for multiple dispatch """ + + +### Functions for on_ambiguity + +def ambiguity_warn(dispatcher, ambiguities): + """ Raise warning when ambiguity is detected + + Parameters + ---------- + dispatcher : Dispatcher + The dispatcher on which the ambiguity was detected + ambiguities : set + Set of type signature pairs that are ambiguous within this dispatcher + + See Also: + Dispatcher.add + warning_text + """ + warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning) + + +class RaiseNotImplementedError: + """Raise ``NotImplementedError`` when called.""" + + def __init__(self, dispatcher): + self.dispatcher = dispatcher + + def __call__(self, *args, **kwargs): + types = tuple(type(a) for a in args) + raise NotImplementedError( + "Ambiguous signature for %s: <%s>" % ( + self.dispatcher.name, str_signature(types) + )) + +def ambiguity_register_error_ignore_dup(dispatcher, ambiguities): + """ + If super signature for ambiguous types is duplicate types, ignore it. + Else, register instance of ``RaiseNotImplementedError`` for ambiguous types. + + Parameters + ---------- + dispatcher : Dispatcher + The dispatcher on which the ambiguity was detected + ambiguities : set + Set of type signature pairs that are ambiguous within this dispatcher + + See Also: + Dispatcher.add + ambiguity_warn + """ + for amb in ambiguities: + signature = tuple(super_signature(amb)) + if len(set(signature)) == 1: + continue + dispatcher.add( + signature, RaiseNotImplementedError(dispatcher), + on_ambiguity=ambiguity_register_error_ignore_dup + ) + +### + + +_unresolved_dispatchers: set[Dispatcher] = set() +_resolve = [True] + + +def halt_ordering(): + _resolve[0] = False + + +def restart_ordering(on_ambiguity=ambiguity_warn): + _resolve[0] = True + while _unresolved_dispatchers: + dispatcher = _unresolved_dispatchers.pop() + dispatcher.reorder(on_ambiguity=on_ambiguity) + + +class Dispatcher: + """ Dispatch methods based on type signature + + Use ``dispatch`` to add implementations + + Examples + -------- + + >>> from sympy.multipledispatch import dispatch + >>> @dispatch(int) + ... def f(x): + ... return x + 1 + + >>> @dispatch(float) + ... def f(x): # noqa: F811 + ... return x - 1 + + >>> f(3) + 4 + >>> f(3.0) + 2.0 + """ + __slots__ = '__name__', 'name', 'funcs', 'ordering', '_cache', 'doc' + + def __init__(self, name, doc=None): + self.name = self.__name__ = name + self.funcs = {} + self._cache = {} + self.ordering = [] + self.doc = doc + + def register(self, *types, **kwargs): + """ Register dispatcher with new implementation + + >>> from sympy.multipledispatch.dispatcher import Dispatcher + >>> f = Dispatcher('f') + >>> @f.register(int) + ... def inc(x): + ... return x + 1 + + >>> @f.register(float) + ... def dec(x): + ... return x - 1 + + >>> @f.register(list) + ... @f.register(tuple) + ... def reverse(x): + ... return x[::-1] + + >>> f(1) + 2 + + >>> f(1.0) + 0.0 + + >>> f([1, 2, 3]) + [3, 2, 1] + """ + def _(func): + self.add(types, func, **kwargs) + return func + return _ + + @classmethod + def get_func_params(cls, func): + if hasattr(inspect, "signature"): + sig = inspect.signature(func) + return sig.parameters.values() + + @classmethod + def get_func_annotations(cls, func): + """ Get annotations of function positional parameters + """ + params = cls.get_func_params(func) + if params: + Parameter = inspect.Parameter + + params = (param for param in params + if param.kind in + (Parameter.POSITIONAL_ONLY, + Parameter.POSITIONAL_OR_KEYWORD)) + + annotations = tuple( + param.annotation + for param in params) + + if not any(ann is Parameter.empty for ann in annotations): + return annotations + + def add(self, signature, func, on_ambiguity=ambiguity_warn): + """ Add new types/method pair to dispatcher + + >>> from sympy.multipledispatch import Dispatcher + >>> D = Dispatcher('add') + >>> D.add((int, int), lambda x, y: x + y) + >>> D.add((float, float), lambda x, y: x + y) + + >>> D(1, 2) + 3 + >>> D(1, 2.0) + Traceback (most recent call last): + ... + NotImplementedError: Could not find signature for add: + + When ``add`` detects a warning it calls the ``on_ambiguity`` callback + with a dispatcher/itself, and a set of ambiguous type signature pairs + as inputs. See ``ambiguity_warn`` for an example. + """ + # Handle annotations + if not signature: + annotations = self.get_func_annotations(func) + if annotations: + signature = annotations + + # Handle union types + if any(isinstance(typ, tuple) for typ in signature): + for typs in expand_tuples(signature): + self.add(typs, func, on_ambiguity) + return + + for typ in signature: + if not isinstance(typ, type): + str_sig = ', '.join(c.__name__ if isinstance(c, type) + else str(c) for c in signature) + raise TypeError("Tried to dispatch on non-type: %s\n" + "In signature: <%s>\n" + "In function: %s" % + (typ, str_sig, self.name)) + + self.funcs[signature] = func + self.reorder(on_ambiguity=on_ambiguity) + self._cache.clear() + + def reorder(self, on_ambiguity=ambiguity_warn): + if _resolve[0]: + self.ordering = ordering(self.funcs) + amb = ambiguities(self.funcs) + if amb: + on_ambiguity(self, amb) + else: + _unresolved_dispatchers.add(self) + + def __call__(self, *args, **kwargs): + types = tuple([type(arg) for arg in args]) + try: + func = self._cache[types] + except KeyError: + func = self.dispatch(*types) + if not func: + raise NotImplementedError( + 'Could not find signature for %s: <%s>' % + (self.name, str_signature(types))) + self._cache[types] = func + try: + return func(*args, **kwargs) + + except MDNotImplementedError: + funcs = self.dispatch_iter(*types) + next(funcs) # burn first + for func in funcs: + try: + return func(*args, **kwargs) + except MDNotImplementedError: + pass + raise NotImplementedError("Matching functions for " + "%s: <%s> found, but none completed successfully" + % (self.name, str_signature(types))) + + def __str__(self): + return "" % self.name + __repr__ = __str__ + + def dispatch(self, *types): + """ Deterimine appropriate implementation for this type signature + + This method is internal. Users should call this object as a function. + Implementation resolution occurs within the ``__call__`` method. + + >>> from sympy.multipledispatch import dispatch + >>> @dispatch(int) + ... def inc(x): + ... return x + 1 + + >>> implementation = inc.dispatch(int) + >>> implementation(3) + 4 + + >>> print(inc.dispatch(float)) + None + + See Also: + ``sympy.multipledispatch.conflict`` - module to determine resolution order + """ + + if types in self.funcs: + return self.funcs[types] + + try: + return next(self.dispatch_iter(*types)) + except StopIteration: + return None + + def dispatch_iter(self, *types): + n = len(types) + for signature in self.ordering: + if len(signature) == n and all(map(issubclass, types, signature)): + result = self.funcs[signature] + yield result + + def resolve(self, types): + """ Deterimine appropriate implementation for this type signature + + .. deprecated:: 0.4.4 + Use ``dispatch(*types)`` instead + """ + warn("resolve() is deprecated, use dispatch(*types)", + DeprecationWarning) + + return self.dispatch(*types) + + def __getstate__(self): + return {'name': self.name, + 'funcs': self.funcs} + + def __setstate__(self, d): + self.name = d['name'] + self.funcs = d['funcs'] + self.ordering = ordering(self.funcs) + self._cache = {} + + @property + def __doc__(self): + docs = ["Multiply dispatched method: %s" % self.name] + + if self.doc: + docs.append(self.doc) + + other = [] + for sig in self.ordering[::-1]: + func = self.funcs[sig] + if func.__doc__: + s = 'Inputs: <%s>\n' % str_signature(sig) + s += '-' * len(s) + '\n' + s += func.__doc__.strip() + docs.append(s) + else: + other.append(str_signature(sig)) + + if other: + docs.append('Other signatures:\n ' + '\n '.join(other)) + + return '\n\n'.join(docs) + + def _help(self, *args): + return self.dispatch(*map(type, args)).__doc__ + + def help(self, *args, **kwargs): + """ Print docstring for the function corresponding to inputs """ + print(self._help(*args)) + + def _source(self, *args): + func = self.dispatch(*map(type, args)) + if not func: + raise TypeError("No function found") + return source(func) + + def source(self, *args, **kwargs): + """ Print source code for the function corresponding to inputs """ + print(self._source(*args)) + + +def source(func): + s = 'File: %s\n\n' % inspect.getsourcefile(func) + s = s + inspect.getsource(func) + return s + + +class MethodDispatcher(Dispatcher): + """ Dispatch methods based on type signature + + See Also: + Dispatcher + """ + + @classmethod + def get_func_params(cls, func): + if hasattr(inspect, "signature"): + sig = inspect.signature(func) + return itl.islice(sig.parameters.values(), 1, None) + + def __get__(self, instance, owner): + self.obj = instance + self.cls = owner + return self + + def __call__(self, *args, **kwargs): + types = tuple([type(arg) for arg in args]) + func = self.dispatch(*types) + if not func: + raise NotImplementedError('Could not find signature for %s: <%s>' % + (self.name, str_signature(types))) + return func(self.obj, *args, **kwargs) + + +def str_signature(sig): + """ String representation of type signature + + >>> from sympy.multipledispatch.dispatcher import str_signature + >>> str_signature((int, float)) + 'int, float' + """ + return ', '.join(cls.__name__ for cls in sig) + + +def warning_text(name, amb): + """ The text for ambiguity warnings """ + text = "\nAmbiguities exist in dispatched function %s\n\n" % (name) + text += "The following signatures may result in ambiguous behavior:\n" + for pair in amb: + text += "\t" + \ + ', '.join('[' + str_signature(s) + ']' for s in pair) + "\n" + text += "\n\nConsider making the following additions:\n\n" + text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s)) + + ')\ndef %s(...)' % name for s in amb]) + return text diff --git a/.venv/lib/python3.13/site-packages/sympy/multipledispatch/utils.py b/.venv/lib/python3.13/site-packages/sympy/multipledispatch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..11f563772385124c2fc0d285f7aa6e0747b8b412 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/multipledispatch/utils.py @@ -0,0 +1,105 @@ +from collections import OrderedDict + + +def expand_tuples(L): + """ + >>> from sympy.multipledispatch.utils import expand_tuples + >>> expand_tuples([1, (2, 3)]) + [(1, 2), (1, 3)] + + >>> expand_tuples([1, 2]) + [(1, 2)] + """ + if not L: + return [()] + elif not isinstance(L[0], tuple): + rest = expand_tuples(L[1:]) + return [(L[0],) + t for t in rest] + else: + rest = expand_tuples(L[1:]) + return [(item,) + t for t in rest for item in L[0]] + + +# Taken from theano/theano/gof/sched.py +# Avoids licensing issues because this was written by Matthew Rocklin +def _toposort(edges): + """ Topological sort algorithm by Kahn [1] - O(nodes + vertices) + + inputs: + edges - a dict of the form {a: {b, c}} where b and c depend on a + outputs: + L - an ordered list of nodes that satisfy the dependencies of edges + + >>> from sympy.multipledispatch.utils import _toposort + >>> _toposort({1: (2, 3), 2: (3, )}) + [1, 2, 3] + + Closely follows the wikipedia page [2] + + [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", + Communications of the ACM + [2] https://en.wikipedia.org/wiki/Toposort#Algorithms + """ + incoming_edges = reverse_dict(edges) + incoming_edges = {k: set(val) for k, val in incoming_edges.items()} + S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges) + L = [] + + while S: + n, _ = S.popitem() + L.append(n) + for m in edges.get(n, ()): + assert n in incoming_edges[m] + incoming_edges[m].remove(n) + if not incoming_edges[m]: + S[m] = None + if any(incoming_edges.get(v, None) for v in edges): + raise ValueError("Input has cycles") + return L + + +def reverse_dict(d): + """Reverses direction of dependence dict + + >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} + >>> reverse_dict(d) # doctest: +SKIP + {1: ('a',), 2: ('a', 'b'), 3: ('b',)} + + :note: dict order are not deterministic. As we iterate on the + input dict, it make the output of this function depend on the + dict order. So this function output order should be considered + as undeterministic. + + """ + result = {} + for key in d: + for val in d[key]: + result[val] = result.get(val, ()) + (key, ) + return result + + +# Taken from toolz +# Avoids licensing issues because this version was authored by Matthew Rocklin +def groupby(func, seq): + """ Group a collection by a key function + + >>> from sympy.multipledispatch.utils import groupby + >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] + >>> groupby(len, names) # doctest: +SKIP + {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} + + >>> iseven = lambda x: x % 2 == 0 + >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP + {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} + + See Also: + ``countby`` + """ + + d = {} + for item in seq: + key = func(item) + if key not in d: + d[key] = [] + d[key].append(item) + return d diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/__init__.py b/.venv/lib/python3.13/site-packages/sympy/plotting/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..074bcf93b7375eb3dc96d16b5450b539074d8f7d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/__init__.py @@ -0,0 +1,22 @@ +from .plot import plot_backends +from .plot_implicit import plot_implicit +from .textplot import textplot +from .pygletplot import PygletPlot +from .plot import PlotGrid +from .plot import (plot, plot_parametric, plot3d, plot3d_parametric_surface, + plot3d_parametric_line, plot_contour) + +__all__ = [ + 'plot_backends', + + 'plot_implicit', + + 'textplot', + + 'PygletPlot', + + 'PlotGrid', + + 'plot', 'plot_parametric', 'plot3d', 'plot3d_parametric_surface', + 'plot3d_parametric_line', 'plot_contour' +] diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/experimental_lambdify.py b/.venv/lib/python3.13/site-packages/sympy/plotting/experimental_lambdify.py new file mode 100644 index 0000000000000000000000000000000000000000..ae17e7adf45f2933ccd71514917199c85d14549e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/experimental_lambdify.py @@ -0,0 +1,641 @@ +""" rewrite of lambdify - This stuff is not stable at all. + +It is for internal use in the new plotting module. +It may (will! see the Q'n'A in the source) be rewritten. + +It's completely self contained. Especially it does not use lambdarepr. + +It does not aim to replace the current lambdify. Most importantly it will never +ever support anything else than SymPy expressions (no Matrices, dictionaries +and so on). +""" + + +import re +from sympy.core.numbers import (I, NumberSymbol, oo, zoo) +from sympy.core.symbol import Symbol +from sympy.utilities.iterables import numbered_symbols + +# We parse the expression string into a tree that identifies functions. Then +# we translate the names of the functions and we translate also some strings +# that are not names of functions (all this according to translation +# dictionaries). +# If the translation goes to another module (like numpy) the +# module is imported and 'func' is translated to 'module.func'. +# If a function can not be translated, the inner nodes of that part of the +# tree are not translated. So if we have Integral(sqrt(x)), sqrt is not +# translated to np.sqrt and the Integral does not crash. +# A namespace for all this is generated by crawling the (func, args) tree of +# the expression. The creation of this namespace involves many ugly +# workarounds. +# The namespace consists of all the names needed for the SymPy expression and +# all the name of modules used for translation. Those modules are imported only +# as a name (import numpy as np) in order to keep the namespace small and +# manageable. + +# Please, if there is a bug, do not try to fix it here! Rewrite this by using +# the method proposed in the last Q'n'A below. That way the new function will +# work just as well, be just as simple, but it wont need any new workarounds. +# If you insist on fixing it here, look at the workarounds in the function +# sympy_expression_namespace and in lambdify. + +# Q: Why are you not using Python abstract syntax tree? +# A: Because it is more complicated and not much more powerful in this case. + +# Q: What if I have Symbol('sin') or g=Function('f')? +# A: You will break the algorithm. We should use srepr to defend against this? +# The problem with Symbol('sin') is that it will be printed as 'sin'. The +# parser will distinguish it from the function 'sin' because functions are +# detected thanks to the opening parenthesis, but the lambda expression won't +# understand the difference if we have also the sin function. +# The solution (complicated) is to use srepr and maybe ast. +# The problem with the g=Function('f') is that it will be printed as 'f' but in +# the global namespace we have only 'g'. But as the same printer is used in the +# constructor of the namespace there will be no problem. + +# Q: What if some of the printers are not printing as expected? +# A: The algorithm wont work. You must use srepr for those cases. But even +# srepr may not print well. All problems with printers should be considered +# bugs. + +# Q: What about _imp_ functions? +# A: Those are taken care for by evalf. A special case treatment will work +# faster but it's not worth the code complexity. + +# Q: Will ast fix all possible problems? +# A: No. You will always have to use some printer. Even srepr may not work in +# some cases. But if the printer does not work, that should be considered a +# bug. + +# Q: Is there same way to fix all possible problems? +# A: Probably by constructing our strings ourself by traversing the (func, +# args) tree and creating the namespace at the same time. That actually sounds +# good. + +from sympy.external import import_module +import warnings + +#TODO debugging output + + +class vectorized_lambdify: + """ Return a sufficiently smart, vectorized and lambdified function. + + Returns only reals. + + Explanation + =========== + + This function uses experimental_lambdify to created a lambdified + expression ready to be used with numpy. Many of the functions in SymPy + are not implemented in numpy so in some cases we resort to Python cmath or + even to evalf. + + The following translations are tried: + only numpy complex + - on errors raised by SymPy trying to work with ndarray: + only Python cmath and then vectorize complex128 + + When using Python cmath there is no need for evalf or float/complex + because Python cmath calls those. + + This function never tries to mix numpy directly with evalf because numpy + does not understand SymPy Float. If this is needed one can use the + float_wrap_evalf/complex_wrap_evalf options of experimental_lambdify or + better one can be explicit about the dtypes that numpy works with. + Check numpy bug http://projects.scipy.org/numpy/ticket/1013 to know what + types of errors to expect. + """ + def __init__(self, args, expr): + self.args = args + self.expr = expr + self.np = import_module('numpy') + + self.lambda_func_1 = experimental_lambdify( + args, expr, use_np=True) + self.vector_func_1 = self.lambda_func_1 + + self.lambda_func_2 = experimental_lambdify( + args, expr, use_python_cmath=True) + self.vector_func_2 = self.np.vectorize( + self.lambda_func_2, otypes=[complex]) + + self.vector_func = self.vector_func_1 + self.failure = False + + def __call__(self, *args): + np = self.np + + try: + temp_args = (np.array(a, dtype=complex) for a in args) + results = self.vector_func(*temp_args) + results = np.ma.masked_where( + np.abs(results.imag) > 1e-7 * np.abs(results), + results.real, copy=False) + return results + except ValueError: + if self.failure: + raise + + self.failure = True + self.vector_func = self.vector_func_2 + warnings.warn( + 'The evaluation of the expression is problematic. ' + 'We are trying a failback method that may still work. ' + 'Please report this as a bug.') + return self.__call__(*args) + + +class lambdify: + """Returns the lambdified function. + + Explanation + =========== + + This function uses experimental_lambdify to create a lambdified + expression. It uses cmath to lambdify the expression. If the function + is not implemented in Python cmath, Python cmath calls evalf on those + functions. + """ + + def __init__(self, args, expr): + self.args = args + self.expr = expr + self.lambda_func_1 = experimental_lambdify( + args, expr, use_python_cmath=True, use_evalf=True) + self.lambda_func_2 = experimental_lambdify( + args, expr, use_python_math=True, use_evalf=True) + self.lambda_func_3 = experimental_lambdify( + args, expr, use_evalf=True, complex_wrap_evalf=True) + self.lambda_func = self.lambda_func_1 + self.failure = False + + def __call__(self, args): + try: + #The result can be sympy.Float. Hence wrap it with complex type. + result = complex(self.lambda_func(args)) + if abs(result.imag) > 1e-7 * abs(result): + return None + return result.real + except (ZeroDivisionError, OverflowError): + return None + except TypeError as e: + if self.failure: + raise e + + if self.lambda_func == self.lambda_func_1: + self.lambda_func = self.lambda_func_2 + return self.__call__(args) + + self.failure = True + self.lambda_func = self.lambda_func_3 + warnings.warn( + 'The evaluation of the expression is problematic. ' + 'We are trying a failback method that may still work. ' + 'Please report this as a bug.', stacklevel=2) + return self.__call__(args) + + +def experimental_lambdify(*args, **kwargs): + l = Lambdifier(*args, **kwargs) + return l + + +class Lambdifier: + def __init__(self, args, expr, print_lambda=False, use_evalf=False, + float_wrap_evalf=False, complex_wrap_evalf=False, + use_np=False, use_python_math=False, use_python_cmath=False, + use_interval=False): + + self.print_lambda = print_lambda + self.use_evalf = use_evalf + self.float_wrap_evalf = float_wrap_evalf + self.complex_wrap_evalf = complex_wrap_evalf + self.use_np = use_np + self.use_python_math = use_python_math + self.use_python_cmath = use_python_cmath + self.use_interval = use_interval + + # Constructing the argument string + # - check + if not all(isinstance(a, Symbol) for a in args): + raise ValueError('The arguments must be Symbols.') + # - use numbered symbols + syms = numbered_symbols(exclude=expr.free_symbols) + newargs = [next(syms) for _ in args] + expr = expr.xreplace(dict(zip(args, newargs))) + argstr = ', '.join([str(a) for a in newargs]) + del syms, newargs, args + + # Constructing the translation dictionaries and making the translation + self.dict_str = self.get_dict_str() + self.dict_fun = self.get_dict_fun() + exprstr = str(expr) + newexpr = self.tree2str_translate(self.str2tree(exprstr)) + + # Constructing the namespaces + namespace = {} + namespace.update(self.sympy_atoms_namespace(expr)) + namespace.update(self.sympy_expression_namespace(expr)) + # XXX Workaround + # Ugly workaround because Pow(a,Half) prints as sqrt(a) + # and sympy_expression_namespace can not catch it. + from sympy.functions.elementary.miscellaneous import sqrt + namespace.update({'sqrt': sqrt}) + namespace.update({'Eq': lambda x, y: x == y}) + namespace.update({'Ne': lambda x, y: x != y}) + # End workaround. + if use_python_math: + namespace.update({'math': __import__('math')}) + if use_python_cmath: + namespace.update({'cmath': __import__('cmath')}) + if use_np: + try: + namespace.update({'np': __import__('numpy')}) + except ImportError: + raise ImportError( + 'experimental_lambdify failed to import numpy.') + if use_interval: + namespace.update({'imath': __import__( + 'sympy.plotting.intervalmath', fromlist=['intervalmath'])}) + namespace.update({'math': __import__('math')}) + + # Construct the lambda + if self.print_lambda: + print(newexpr) + eval_str = 'lambda %s : ( %s )' % (argstr, newexpr) + self.eval_str = eval_str + exec("MYNEWLAMBDA = %s" % eval_str, namespace) + self.lambda_func = namespace['MYNEWLAMBDA'] + + def __call__(self, *args, **kwargs): + return self.lambda_func(*args, **kwargs) + + + ############################################################################## + # Dicts for translating from SymPy to other modules + ############################################################################## + ### + # builtins + ### + # Functions with different names in builtins + builtin_functions_different = { + 'Min': 'min', + 'Max': 'max', + 'Abs': 'abs', + } + + # Strings that should be translated + builtin_not_functions = { + 'I': '1j', +# 'oo': '1e400', + } + + ### + # numpy + ### + + # Functions that are the same in numpy + numpy_functions_same = [ + 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'exp', 'log', + 'sqrt', 'floor', 'conjugate', 'sign', + ] + + # Functions with different names in numpy + numpy_functions_different = { + "acos": "arccos", + "acosh": "arccosh", + "arg": "angle", + "asin": "arcsin", + "asinh": "arcsinh", + "atan": "arctan", + "atan2": "arctan2", + "atanh": "arctanh", + "ceiling": "ceil", + "im": "imag", + "ln": "log", + "Max": "amax", + "Min": "amin", + "re": "real", + "Abs": "abs", + } + + # Strings that should be translated + numpy_not_functions = { + 'pi': 'np.pi', + 'oo': 'np.inf', + 'E': 'np.e', + } + + ### + # Python math + ### + + # Functions that are the same in math + math_functions_same = [ + 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'atan2', + 'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh', + 'exp', 'log', 'erf', 'sqrt', 'floor', 'factorial', 'gamma', + ] + + # Functions with different names in math + math_functions_different = { + 'ceiling': 'ceil', + 'ln': 'log', + 'loggamma': 'lgamma' + } + + # Strings that should be translated + math_not_functions = { + 'pi': 'math.pi', + 'E': 'math.e', + } + + ### + # Python cmath + ### + + # Functions that are the same in cmath + cmath_functions_same = [ + 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', + 'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh', + 'exp', 'log', 'sqrt', + ] + + # Functions with different names in cmath + cmath_functions_different = { + 'ln': 'log', + 'arg': 'phase', + } + + # Strings that should be translated + cmath_not_functions = { + 'pi': 'cmath.pi', + 'E': 'cmath.e', + } + + ### + # intervalmath + ### + + interval_not_functions = { + 'pi': 'math.pi', + 'E': 'math.e' + } + + interval_functions_same = [ + 'sin', 'cos', 'exp', 'tan', 'atan', 'log', + 'sqrt', 'cosh', 'sinh', 'tanh', 'floor', + 'acos', 'asin', 'acosh', 'asinh', 'atanh', + 'Abs', 'And', 'Or' + ] + + interval_functions_different = { + 'Min': 'imin', + 'Max': 'imax', + 'ceiling': 'ceil', + + } + + ### + # mpmath, etc + ### + #TODO + + ### + # Create the final ordered tuples of dictionaries + ### + + # For strings + def get_dict_str(self): + dict_str = dict(self.builtin_not_functions) + if self.use_np: + dict_str.update(self.numpy_not_functions) + if self.use_python_math: + dict_str.update(self.math_not_functions) + if self.use_python_cmath: + dict_str.update(self.cmath_not_functions) + if self.use_interval: + dict_str.update(self.interval_not_functions) + return dict_str + + # For functions + def get_dict_fun(self): + dict_fun = dict(self.builtin_functions_different) + if self.use_np: + for s in self.numpy_functions_same: + dict_fun[s] = 'np.' + s + for k, v in self.numpy_functions_different.items(): + dict_fun[k] = 'np.' + v + if self.use_python_math: + for s in self.math_functions_same: + dict_fun[s] = 'math.' + s + for k, v in self.math_functions_different.items(): + dict_fun[k] = 'math.' + v + if self.use_python_cmath: + for s in self.cmath_functions_same: + dict_fun[s] = 'cmath.' + s + for k, v in self.cmath_functions_different.items(): + dict_fun[k] = 'cmath.' + v + if self.use_interval: + for s in self.interval_functions_same: + dict_fun[s] = 'imath.' + s + for k, v in self.interval_functions_different.items(): + dict_fun[k] = 'imath.' + v + return dict_fun + + ############################################################################## + # The translator functions, tree parsers, etc. + ############################################################################## + + def str2tree(self, exprstr): + """Converts an expression string to a tree. + + Explanation + =========== + + Functions are represented by ('func_name(', tree_of_arguments). + Other expressions are (head_string, mid_tree, tail_str). + Expressions that do not contain functions are directly returned. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy import Integral, sin + >>> from sympy.plotting.experimental_lambdify import Lambdifier + >>> str2tree = Lambdifier([x], x).str2tree + + >>> str2tree(str(Integral(x, (x, 1, y)))) + ('', ('Integral(', 'x, (x, 1, y)'), ')') + >>> str2tree(str(x+y)) + 'x + y' + >>> str2tree(str(x+y*sin(z)+1)) + ('x + y*', ('sin(', 'z'), ') + 1') + >>> str2tree('sin(y*(y + 1.1) + (sin(y)))') + ('', ('sin(', ('y*(y + 1.1) + (', ('sin(', 'y'), '))')), ')') + """ + #matches the first 'function_name(' + first_par = re.search(r'(\w+\()', exprstr) + if first_par is None: + return exprstr + else: + start = first_par.start() + end = first_par.end() + head = exprstr[:start] + func = exprstr[start:end] + tail = exprstr[end:] + count = 0 + for i, c in enumerate(tail): + if c == '(': + count += 1 + elif c == ')': + count -= 1 + if count == -1: + break + func_tail = self.str2tree(tail[:i]) + tail = self.str2tree(tail[i:]) + return (head, (func, func_tail), tail) + + @classmethod + def tree2str(cls, tree): + """Converts a tree to string without translations. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy import sin + >>> from sympy.plotting.experimental_lambdify import Lambdifier + >>> str2tree = Lambdifier([x], x).str2tree + >>> tree2str = Lambdifier([x], x).tree2str + + >>> tree2str(str2tree(str(x+y*sin(z)+1))) + 'x + y*sin(z) + 1' + """ + if isinstance(tree, str): + return tree + else: + return ''.join(map(cls.tree2str, tree)) + + def tree2str_translate(self, tree): + """Converts a tree to string with translations. + + Explanation + =========== + + Function names are translated by translate_func. + Other strings are translated by translate_str. + """ + if isinstance(tree, str): + return self.translate_str(tree) + elif isinstance(tree, tuple) and len(tree) == 2: + return self.translate_func(tree[0][:-1], tree[1]) + else: + return ''.join([self.tree2str_translate(t) for t in tree]) + + def translate_str(self, estr): + """Translate substrings of estr using in order the dictionaries in + dict_tuple_str.""" + for pattern, repl in self.dict_str.items(): + estr = re.sub(pattern, repl, estr) + return estr + + def translate_func(self, func_name, argtree): + """Translate function names and the tree of arguments. + + Explanation + =========== + + If the function name is not in the dictionaries of dict_tuple_fun then the + function is surrounded by a float((...).evalf()). + + The use of float is necessary as np.(sympy.Float(..)) raises an + error.""" + if func_name in self.dict_fun: + new_name = self.dict_fun[func_name] + argstr = self.tree2str_translate(argtree) + return new_name + '(' + argstr + elif func_name in ['Eq', 'Ne']: + op = {'Eq': '==', 'Ne': '!='} + return "(lambda x, y: x {} y)({}".format(op[func_name], self.tree2str_translate(argtree)) + else: + template = '(%s(%s)).evalf(' if self.use_evalf else '%s(%s' + if self.float_wrap_evalf: + template = 'float(%s)' % template + elif self.complex_wrap_evalf: + template = 'complex(%s)' % template + + # Wrapping should only happen on the outermost expression, which + # is the only thing we know will be a number. + float_wrap_evalf = self.float_wrap_evalf + complex_wrap_evalf = self.complex_wrap_evalf + self.float_wrap_evalf = False + self.complex_wrap_evalf = False + ret = template % (func_name, self.tree2str_translate(argtree)) + self.float_wrap_evalf = float_wrap_evalf + self.complex_wrap_evalf = complex_wrap_evalf + return ret + + ############################################################################## + # The namespace constructors + ############################################################################## + + @classmethod + def sympy_expression_namespace(cls, expr): + """Traverses the (func, args) tree of an expression and creates a SymPy + namespace. All other modules are imported only as a module name. That way + the namespace is not polluted and rests quite small. It probably causes much + more variable lookups and so it takes more time, but there are no tests on + that for the moment.""" + if expr is None: + return {} + else: + funcname = str(expr.func) + # XXX Workaround + # Here we add an ugly workaround because str(func(x)) + # is not always the same as str(func). Eg + # >>> str(Integral(x)) + # "Integral(x)" + # >>> str(Integral) + # "" + # >>> str(sqrt(x)) + # "sqrt(x)" + # >>> str(sqrt) + # "" + # >>> str(sin(x)) + # "sin(x)" + # >>> str(sin) + # "sin" + # Either one of those can be used but not all at the same time. + # The code considers the sin example as the right one. + regexlist = [ + r'$', + # the example Integral + r'$', # the example sqrt + ] + for r in regexlist: + m = re.match(r, funcname) + if m is not None: + funcname = m.groups()[0] + # End of the workaround + # XXX debug: print funcname + args_dict = {} + for a in expr.args: + if (isinstance(a, (Symbol, NumberSymbol)) or a in [I, zoo, oo]): + continue + else: + args_dict.update(cls.sympy_expression_namespace(a)) + args_dict.update({funcname: expr.func}) + return args_dict + + @staticmethod + def sympy_atoms_namespace(expr): + """For no real reason this function is separated from + sympy_expression_namespace. It can be moved to it.""" + atoms = expr.atoms(Symbol, NumberSymbol, I, zoo, oo) + d = {} + for a in atoms: + # XXX debug: print 'atom:' + str(a) + d[str(a)] = a + return d diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/plot.py b/.venv/lib/python3.13/site-packages/sympy/plotting/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..50029392a1ac70491f93f28c4d443da15e7fc31e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/plot.py @@ -0,0 +1,1234 @@ +"""Plotting module for SymPy. + +A plot is represented by the ``Plot`` class that contains a reference to the +backend and a list of the data series to be plotted. The data series are +instances of classes meant to simplify getting points and meshes from SymPy +expressions. ``plot_backends`` is a dictionary with all the backends. + +This module gives only the essential. For all the fancy stuff use directly +the backend. You can get the backend wrapper for every plot from the +``_backend`` attribute. Moreover the data series classes have various useful +methods like ``get_points``, ``get_meshes``, etc, that may +be useful if you wish to use another plotting library. + +Especially if you need publication ready graphs and this module is not enough +for you - just get the ``_backend`` attribute and add whatever you want +directly to it. In the case of matplotlib (the common way to graph data in +python) just copy ``_backend.fig`` which is the figure and ``_backend.ax`` +which is the axis and work on them as you would on any other matplotlib object. + +Simplicity of code takes much greater importance than performance. Do not use it +if you care at all about performance. A new backend instance is initialized +every time you call ``show()`` and the old one is left to the garbage collector. +""" + +from sympy.concrete.summations import Sum +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import Function, AppliedUndef +from sympy.core.symbol import (Dummy, Symbol, Wild) +from sympy.external import import_module +from sympy.functions import sign +from sympy.plotting.backends.base_backend import Plot +from sympy.plotting.backends.matplotlibbackend import MatplotlibBackend +from sympy.plotting.backends.textbackend import TextBackend +from sympy.plotting.series import ( + LineOver1DRangeSeries, Parametric2DLineSeries, Parametric3DLineSeries, + ParametricSurfaceSeries, SurfaceOver2DRangeSeries, ContourSeries) +from sympy.plotting.utils import _check_arguments, _plot_sympify +from sympy.tensor.indexed import Indexed +# to maintain back-compatibility +from sympy.plotting.plotgrid import PlotGrid # noqa: F401 +from sympy.plotting.series import BaseSeries # noqa: F401 +from sympy.plotting.series import Line2DBaseSeries # noqa: F401 +from sympy.plotting.series import Line3DBaseSeries # noqa: F401 +from sympy.plotting.series import SurfaceBaseSeries # noqa: F401 +from sympy.plotting.series import List2DSeries # noqa: F401 +from sympy.plotting.series import GenericDataSeries # noqa: F401 +from sympy.plotting.series import centers_of_faces # noqa: F401 +from sympy.plotting.series import centers_of_segments # noqa: F401 +from sympy.plotting.series import flat # noqa: F401 +from sympy.plotting.backends.base_backend import unset_show # noqa: F401 +from sympy.plotting.backends.matplotlibbackend import _matplotlib_list # noqa: F401 +from sympy.plotting.textplot import textplot # noqa: F401 + + +__doctest_requires__ = { + ('plot3d', + 'plot3d_parametric_line', + 'plot3d_parametric_surface', + 'plot_parametric'): ['matplotlib'], + # XXX: The plot doctest possibly should not require matplotlib. It fails at + # plot(x**2, (x, -5, 5)) which should be fine for text backend. + ('plot',): ['matplotlib'], +} + + +def _process_summations(sum_bound, *args): + """Substitute oo (infinity) in the lower/upper bounds of a summation with + some integer number. + + Parameters + ========== + + sum_bound : int + oo will be substituted with this integer number. + *args : list/tuple + pre-processed arguments of the form (expr, range, ...) + + Notes + ===== + Let's consider the following summation: ``Sum(1 / x**2, (x, 1, oo))``. + The current implementation of lambdify (SymPy 1.12 at the time of + writing this) will create something of this form: + ``sum(1 / x**2 for x in range(1, INF))`` + The problem is that ``type(INF)`` is float, while ``range`` requires + integers: the evaluation fails. + Instead of modifying ``lambdify`` (which requires a deep knowledge), just + replace it with some integer number. + """ + def new_bound(t, bound): + if (not t.is_number) or t.is_finite: + return t + if sign(t) >= 0: + return bound + return -bound + + args = list(args) + expr = args[0] + + # select summations whose lower/upper bound is infinity + w = Wild("w", properties=[ + lambda t: isinstance(t, Sum), + lambda t: any((not a[1].is_finite) or (not a[2].is_finite) for i, a in enumerate(t.args) if i > 0) + ]) + + for t in list(expr.find(w)): + sums_args = list(t.args) + for i, a in enumerate(sums_args): + if i > 0: + sums_args[i] = (a[0], new_bound(a[1], sum_bound), + new_bound(a[2], sum_bound)) + s = Sum(*sums_args) + expr = expr.subs(t, s) + args[0] = expr + return args + + +def _build_line_series(*args, **kwargs): + """Loop over the provided arguments and create the necessary line series. + """ + series = [] + sum_bound = int(kwargs.get("sum_bound", 1000)) + for arg in args: + expr, r, label, rendering_kw = arg + kw = kwargs.copy() + if rendering_kw is not None: + kw["rendering_kw"] = rendering_kw + # TODO: _process_piecewise check goes here + if not callable(expr): + arg = _process_summations(sum_bound, *arg) + series.append(LineOver1DRangeSeries(*arg[:-1], **kw)) + return series + + +def _create_series(series_type, plot_expr, **kwargs): + """Extract the rendering_kw dictionary from the provided arguments and + create an appropriate data series. + """ + series = [] + for args in plot_expr: + kw = kwargs.copy() + if args[-1] is not None: + kw["rendering_kw"] = args[-1] + series.append(series_type(*args[:-1], **kw)) + return series + + +def _set_labels(series, labels, rendering_kw): + """Apply the `label` and `rendering_kw` keyword arguments to the series. + """ + if not isinstance(labels, (list, tuple)): + labels = [labels] + if len(labels) > 0: + if len(labels) == 1 and len(series) > 1: + # if one label is provided and multiple series are being plotted, + # set the same label to all data series. It maintains + # back-compatibility + labels *= len(series) + if len(series) != len(labels): + raise ValueError("The number of labels must be equal to the " + "number of expressions being plotted.\nReceived " + f"{len(series)} expressions and {len(labels)} labels") + + for s, l in zip(series, labels): + s.label = l + + if rendering_kw: + if isinstance(rendering_kw, dict): + rendering_kw = [rendering_kw] + if len(rendering_kw) == 1: + rendering_kw *= len(series) + elif len(series) != len(rendering_kw): + raise ValueError("The number of rendering dictionaries must be " + "equal to the number of expressions being plotted.\nReceived " + f"{len(series)} expressions and {len(labels)} labels") + for s, r in zip(series, rendering_kw): + s.rendering_kw = r + + +def plot_factory(*args, **kwargs): + backend = kwargs.pop("backend", "default") + if isinstance(backend, str): + if backend == "default": + matplotlib = import_module('matplotlib', + min_module_version='1.1.0', catch=(RuntimeError,)) + if matplotlib: + return MatplotlibBackend(*args, **kwargs) + return TextBackend(*args, **kwargs) + return plot_backends[backend](*args, **kwargs) + elif (type(backend) == type) and issubclass(backend, Plot): + return backend(*args, **kwargs) + else: + raise TypeError("backend must be either a string or a subclass of ``Plot``.") + + +plot_backends = { + 'matplotlib': MatplotlibBackend, + 'text': TextBackend, +} + + +####New API for plotting module #### + +# TODO: Add color arrays for plots. +# TODO: Add more plotting options for 3d plots. +# TODO: Adaptive sampling for 3D plots. + +def plot(*args, show=True, **kwargs): + """Plots a function of a single variable as a curve. + + Parameters + ========== + + args : + The first argument is the expression representing the function + of single variable to be plotted. + + The last argument is a 3-tuple denoting the range of the free + variable. e.g. ``(x, 0, 5)`` + + Typical usage examples are in the following: + + - Plotting a single expression with a single range. + ``plot(expr, range, **kwargs)`` + - Plotting a single expression with the default range (-10, 10). + ``plot(expr, **kwargs)`` + - Plotting multiple expressions with a single range. + ``plot(expr1, expr2, ..., range, **kwargs)`` + - Plotting multiple expressions with multiple ranges. + ``plot((expr1, range1), (expr2, range2), ..., **kwargs)`` + + It is best practice to specify range explicitly because default + range may change in the future if a more advanced default range + detection algorithm is implemented. + + show : bool, optional + The default value is set to ``True``. Set show to ``False`` and + the function will not display the plot. The returned instance of + the ``Plot`` class can then be used to save or display the plot + by calling the ``save()`` and ``show()`` methods respectively. + + line_color : string, or float, or function, optional + Specifies the color for the plot. + See ``Plot`` to see how to set color for the plots. + Note that by setting ``line_color``, it would be applied simultaneously + to all the series. + + title : str, optional + Title of the plot. It is set to the latex representation of + the expression, if the plot has only one expression. + + label : str, optional + The label of the expression in the plot. It will be used when + called with ``legend``. Default is the name of the expression. + e.g. ``sin(x)`` + + xlabel : str or expression, optional + Label for the x-axis. + + ylabel : str or expression, optional + Label for the y-axis. + + xscale : 'linear' or 'log', optional + Sets the scaling of the x-axis. + + yscale : 'linear' or 'log', optional + Sets the scaling of the y-axis. + + axis_center : (float, float), optional + Tuple of two floats denoting the coordinates of the center or + {'center', 'auto'} + + xlim : (float, float), optional + Denotes the x-axis limits, ``(min, max)```. + + ylim : (float, float), optional + Denotes the y-axis limits, ``(min, max)```. + + annotations : list, optional + A list of dictionaries specifying the type of annotation + required. The keys in the dictionary should be equivalent + to the arguments of the :external:mod:`matplotlib`'s + :external:meth:`~matplotlib.axes.Axes.annotate` method. + + markers : list, optional + A list of dictionaries specifying the type the markers required. + The keys in the dictionary should be equivalent to the arguments + of the :external:mod:`matplotlib`'s :external:func:`~matplotlib.pyplot.plot()` function + along with the marker related keyworded arguments. + + rectangles : list, optional + A list of dictionaries specifying the dimensions of the + rectangles to be plotted. The keys in the dictionary should be + equivalent to the arguments of the :external:mod:`matplotlib`'s + :external:class:`~matplotlib.patches.Rectangle` class. + + fill : dict, optional + A dictionary specifying the type of color filling required in + the plot. The keys in the dictionary should be equivalent to the + arguments of the :external:mod:`matplotlib`'s + :external:meth:`~matplotlib.axes.Axes.fill_between` method. + + adaptive : bool, optional + The default value for the ``adaptive`` parameter is now ``False``. + To enable adaptive sampling, set ``adaptive=True`` and specify ``n`` if uniform sampling is required. + + The plotting uses an adaptive algorithm which samples + recursively to accurately plot. The adaptive algorithm uses a + random point near the midpoint of two points that has to be + further sampled. Hence the same plots can appear slightly + different. + + depth : int, optional + Recursion depth of the adaptive algorithm. A depth of value + `n` samples a maximum of `2^{n}` points. + + If the ``adaptive`` flag is set to ``False``, this will be + ignored. + + n : int, optional + Used when the ``adaptive`` is set to ``False``. The function + is uniformly sampled at ``n`` number of points. If the ``adaptive`` + flag is set to ``True``, this will be ignored. + This keyword argument replaces ``nb_of_points``, which should be + considered deprecated. + + size : (float, float), optional + A tuple in the form (width, height) in inches to specify the size of + the overall figure. The default value is set to ``None``, meaning + the size will be set by the default backend. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.plotting import plot + >>> x = symbols('x') + + Single Plot + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot(x**2, (x, -5, 5)) + Plot object containing: + [0]: cartesian line: x**2 for x over (-5.0, 5.0) + + Multiple plots with single range. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot(x, x**2, x**3, (x, -5, 5)) + Plot object containing: + [0]: cartesian line: x for x over (-5.0, 5.0) + [1]: cartesian line: x**2 for x over (-5.0, 5.0) + [2]: cartesian line: x**3 for x over (-5.0, 5.0) + + Multiple plots with different ranges. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot((x**2, (x, -6, 6)), (x, (x, -5, 5))) + Plot object containing: + [0]: cartesian line: x**2 for x over (-6.0, 6.0) + [1]: cartesian line: x for x over (-5.0, 5.0) + + No adaptive sampling by default. If adaptive sampling is required, set ``adaptive=True``. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot(x**2, adaptive=True, n=400) + Plot object containing: + [0]: cartesian line: x**2 for x over (-10.0, 10.0) + + See Also + ======== + + Plot, LineOver1DRangeSeries + + """ + args = _plot_sympify(args) + plot_expr = _check_arguments(args, 1, 1, **kwargs) + params = kwargs.get("params", None) + free = set() + for p in plot_expr: + if not isinstance(p[1][0], str): + free |= {p[1][0]} + else: + free |= {Symbol(p[1][0])} + if params: + free = free.difference(params.keys()) + x = free.pop() if free else Symbol("x") + kwargs.setdefault('xlabel', x) + kwargs.setdefault('ylabel', Function('f')(x)) + + labels = kwargs.pop("label", []) + rendering_kw = kwargs.pop("rendering_kw", None) + series = _build_line_series(*plot_expr, **kwargs) + _set_labels(series, labels, rendering_kw) + + plots = plot_factory(*series, **kwargs) + if show: + plots.show() + return plots + + +def plot_parametric(*args, show=True, **kwargs): + """ + Plots a 2D parametric curve. + + Parameters + ========== + + args + Common specifications are: + + - Plotting a single parametric curve with a range + ``plot_parametric((expr_x, expr_y), range)`` + - Plotting multiple parametric curves with the same range + ``plot_parametric((expr_x, expr_y), ..., range)`` + - Plotting multiple parametric curves with different ranges + ``plot_parametric((expr_x, expr_y, range), ...)`` + + ``expr_x`` is the expression representing $x$ component of the + parametric function. + + ``expr_y`` is the expression representing $y$ component of the + parametric function. + + ``range`` is a 3-tuple denoting the parameter symbol, start and + stop. For example, ``(u, 0, 5)``. + + If the range is not specified, then a default range of (-10, 10) + is used. + + However, if the arguments are specified as + ``(expr_x, expr_y, range), ...``, you must specify the ranges + for each expressions manually. + + Default range may change in the future if a more advanced + algorithm is implemented. + + adaptive : bool, optional + Specifies whether to use the adaptive sampling or not. + + The default value is set to ``True``. Set adaptive to ``False`` + and specify ``n`` if uniform sampling is required. + + depth : int, optional + The recursion depth of the adaptive algorithm. A depth of + value $n$ samples a maximum of $2^n$ points. + + n : int, optional + Used when the ``adaptive`` flag is set to ``False``. Specifies the + number of the points used for the uniform sampling. + This keyword argument replaces ``nb_of_points``, which should be + considered deprecated. + + line_color : string, or float, or function, optional + Specifies the color for the plot. + See ``Plot`` to see how to set color for the plots. + Note that by setting ``line_color``, it would be applied simultaneously + to all the series. + + label : str, optional + The label of the expression in the plot. It will be used when + called with ``legend``. Default is the name of the expression. + e.g. ``sin(x)`` + + xlabel : str, optional + Label for the x-axis. + + ylabel : str, optional + Label for the y-axis. + + xscale : 'linear' or 'log', optional + Sets the scaling of the x-axis. + + yscale : 'linear' or 'log', optional + Sets the scaling of the y-axis. + + axis_center : (float, float), optional + Tuple of two floats denoting the coordinates of the center or + {'center', 'auto'} + + xlim : (float, float), optional + Denotes the x-axis limits, ``(min, max)```. + + ylim : (float, float), optional + Denotes the y-axis limits, ``(min, max)```. + + size : (float, float), optional + A tuple in the form (width, height) in inches to specify the size of + the overall figure. The default value is set to ``None``, meaning + the size will be set by the default backend. + + Examples + ======== + + .. plot:: + :context: reset + :format: doctest + :include-source: True + + >>> from sympy import plot_parametric, symbols, cos, sin + >>> u = symbols('u') + + A parametric plot with a single expression: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot_parametric((cos(u), sin(u)), (u, -5, 5)) + Plot object containing: + [0]: parametric cartesian line: (cos(u), sin(u)) for u over (-5.0, 5.0) + + A parametric plot with multiple expressions with the same range: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot_parametric((cos(u), sin(u)), (u, cos(u)), (u, -10, 10)) + Plot object containing: + [0]: parametric cartesian line: (cos(u), sin(u)) for u over (-10.0, 10.0) + [1]: parametric cartesian line: (u, cos(u)) for u over (-10.0, 10.0) + + A parametric plot with multiple expressions with different ranges + for each curve: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot_parametric((cos(u), sin(u), (u, -5, 5)), + ... (cos(u), u, (u, -5, 5))) + Plot object containing: + [0]: parametric cartesian line: (cos(u), sin(u)) for u over (-5.0, 5.0) + [1]: parametric cartesian line: (cos(u), u) for u over (-5.0, 5.0) + + Notes + ===== + + The plotting uses an adaptive algorithm which samples recursively to + accurately plot the curve. The adaptive algorithm uses a random point + near the midpoint of two points that has to be further sampled. + Hence, repeating the same plot command can give slightly different + results because of the random sampling. + + If there are multiple plots, then the same optional arguments are + applied to all the plots drawn in the same canvas. If you want to + set these options separately, you can index the returned ``Plot`` + object and set it. + + For example, when you specify ``line_color`` once, it would be + applied simultaneously to both series. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import pi + >>> expr1 = (u, cos(2*pi*u)/2 + 1/2) + >>> expr2 = (u, sin(2*pi*u)/2 + 1/2) + >>> p = plot_parametric(expr1, expr2, (u, 0, 1), line_color='blue') + + If you want to specify the line color for the specific series, you + should index each item and apply the property manually. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p[0].line_color = 'red' + >>> p.show() + + See Also + ======== + + Plot, Parametric2DLineSeries + """ + args = _plot_sympify(args) + plot_expr = _check_arguments(args, 2, 1, **kwargs) + + labels = kwargs.pop("label", []) + rendering_kw = kwargs.pop("rendering_kw", None) + series = _create_series(Parametric2DLineSeries, plot_expr, **kwargs) + _set_labels(series, labels, rendering_kw) + + plots = plot_factory(*series, **kwargs) + if show: + plots.show() + return plots + + +def plot3d_parametric_line(*args, show=True, **kwargs): + """ + Plots a 3D parametric line plot. + + Usage + ===== + + Single plot: + + ``plot3d_parametric_line(expr_x, expr_y, expr_z, range, **kwargs)`` + + If the range is not specified, then a default range of (-10, 10) is used. + + Multiple plots. + + ``plot3d_parametric_line((expr_x, expr_y, expr_z, range), ..., **kwargs)`` + + Ranges have to be specified for every expression. + + Default range may change in the future if a more advanced default range + detection algorithm is implemented. + + Arguments + ========= + + expr_x : Expression representing the function along x. + + expr_y : Expression representing the function along y. + + expr_z : Expression representing the function along z. + + range : (:class:`~.Symbol`, float, float) + A 3-tuple denoting the range of the parameter variable, e.g., (u, 0, 5). + + Keyword Arguments + ================= + + Arguments for ``Parametric3DLineSeries`` class. + + n : int + The range is uniformly sampled at ``n`` number of points. + This keyword argument replaces ``nb_of_points``, which should be + considered deprecated. + + Aesthetics: + + line_color : string, or float, or function, optional + Specifies the color for the plot. + See ``Plot`` to see how to set color for the plots. + Note that by setting ``line_color``, it would be applied simultaneously + to all the series. + + label : str + The label to the plot. It will be used when called with ``legend=True`` + to denote the function with the given label in the plot. + + If there are multiple plots, then the same series arguments are applied to + all the plots. If you want to set these options separately, you can index + the returned ``Plot`` object and set it. + + Arguments for ``Plot`` class. + + title : str + Title of the plot. + + size : (float, float), optional + A tuple in the form (width, height) in inches to specify the size of + the overall figure. The default value is set to ``None``, meaning + the size will be set by the default backend. + + Examples + ======== + + .. plot:: + :context: reset + :format: doctest + :include-source: True + + >>> from sympy import symbols, cos, sin + >>> from sympy.plotting import plot3d_parametric_line + >>> u = symbols('u') + + Single plot. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot3d_parametric_line(cos(u), sin(u), u, (u, -5, 5)) + Plot object containing: + [0]: 3D parametric cartesian line: (cos(u), sin(u), u) for u over (-5.0, 5.0) + + + Multiple plots. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot3d_parametric_line((cos(u), sin(u), u, (u, -5, 5)), + ... (sin(u), u**2, u, (u, -5, 5))) + Plot object containing: + [0]: 3D parametric cartesian line: (cos(u), sin(u), u) for u over (-5.0, 5.0) + [1]: 3D parametric cartesian line: (sin(u), u**2, u) for u over (-5.0, 5.0) + + + See Also + ======== + + Plot, Parametric3DLineSeries + + """ + args = _plot_sympify(args) + plot_expr = _check_arguments(args, 3, 1, **kwargs) + kwargs.setdefault("xlabel", "x") + kwargs.setdefault("ylabel", "y") + kwargs.setdefault("zlabel", "z") + + labels = kwargs.pop("label", []) + rendering_kw = kwargs.pop("rendering_kw", None) + series = _create_series(Parametric3DLineSeries, plot_expr, **kwargs) + _set_labels(series, labels, rendering_kw) + + plots = plot_factory(*series, **kwargs) + if show: + plots.show() + return plots + + +def _plot3d_plot_contour_helper(Series, *args, **kwargs): + """plot3d and plot_contour are structurally identical. Let's reduce + code repetition. + """ + # NOTE: if this import would be at the top-module level, it would trigger + # SymPy's optional-dependencies tests to fail. + from sympy.vector import BaseScalar + + args = _plot_sympify(args) + plot_expr = _check_arguments(args, 1, 2, **kwargs) + + free_x = set() + free_y = set() + _types = (Symbol, BaseScalar, Indexed, AppliedUndef) + for p in plot_expr: + free_x |= {p[1][0]} if isinstance(p[1][0], _types) else {Symbol(p[1][0])} + free_y |= {p[2][0]} if isinstance(p[2][0], _types) else {Symbol(p[2][0])} + x = free_x.pop() if free_x else Symbol("x") + y = free_y.pop() if free_y else Symbol("y") + kwargs.setdefault("xlabel", x) + kwargs.setdefault("ylabel", y) + kwargs.setdefault("zlabel", Function('f')(x, y)) + + # if a polar discretization is requested and automatic labelling has ben + # applied, hide the labels on the x-y axis. + if kwargs.get("is_polar", False): + if callable(kwargs["xlabel"]): + kwargs["xlabel"] = "" + if callable(kwargs["ylabel"]): + kwargs["ylabel"] = "" + + labels = kwargs.pop("label", []) + rendering_kw = kwargs.pop("rendering_kw", None) + series = _create_series(Series, plot_expr, **kwargs) + _set_labels(series, labels, rendering_kw) + plots = plot_factory(*series, **kwargs) + if kwargs.get("show", True): + plots.show() + return plots + + +def plot3d(*args, show=True, **kwargs): + """ + Plots a 3D surface plot. + + Usage + ===== + + Single plot + + ``plot3d(expr, range_x, range_y, **kwargs)`` + + If the ranges are not specified, then a default range of (-10, 10) is used. + + Multiple plot with the same range. + + ``plot3d(expr1, expr2, range_x, range_y, **kwargs)`` + + If the ranges are not specified, then a default range of (-10, 10) is used. + + Multiple plots with different ranges. + + ``plot3d((expr1, range_x, range_y), (expr2, range_x, range_y), ..., **kwargs)`` + + Ranges have to be specified for every expression. + + Default range may change in the future if a more advanced default range + detection algorithm is implemented. + + Arguments + ========= + + expr : Expression representing the function along x. + + range_x : (:class:`~.Symbol`, float, float) + A 3-tuple denoting the range of the x variable, e.g. (x, 0, 5). + + range_y : (:class:`~.Symbol`, float, float) + A 3-tuple denoting the range of the y variable, e.g. (y, 0, 5). + + Keyword Arguments + ================= + + Arguments for ``SurfaceOver2DRangeSeries`` class: + + n1 : int + The x range is sampled uniformly at ``n1`` of points. + This keyword argument replaces ``nb_of_points_x``, which should be + considered deprecated. + + n2 : int + The y range is sampled uniformly at ``n2`` of points. + This keyword argument replaces ``nb_of_points_y``, which should be + considered deprecated. + + Aesthetics: + + surface_color : Function which returns a float + Specifies the color for the surface of the plot. + See :class:`~.Plot` for more details. + + If there are multiple plots, then the same series arguments are applied to + all the plots. If you want to set these options separately, you can index + the returned ``Plot`` object and set it. + + Arguments for ``Plot`` class: + + title : str + Title of the plot. + + size : (float, float), optional + A tuple in the form (width, height) in inches to specify the size of the + overall figure. The default value is set to ``None``, meaning the size will + be set by the default backend. + + Examples + ======== + + .. plot:: + :context: reset + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.plotting import plot3d + >>> x, y = symbols('x y') + + Single plot + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot3d(x*y, (x, -5, 5), (y, -5, 5)) + Plot object containing: + [0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0) + + + Multiple plots with same range + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot3d(x*y, -x*y, (x, -5, 5), (y, -5, 5)) + Plot object containing: + [0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0) + [1]: cartesian surface: -x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0) + + + Multiple plots with different ranges. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot3d((x**2 + y**2, (x, -5, 5), (y, -5, 5)), + ... (x*y, (x, -3, 3), (y, -3, 3))) + Plot object containing: + [0]: cartesian surface: x**2 + y**2 for x over (-5.0, 5.0) and y over (-5.0, 5.0) + [1]: cartesian surface: x*y for x over (-3.0, 3.0) and y over (-3.0, 3.0) + + + See Also + ======== + + Plot, SurfaceOver2DRangeSeries + + """ + kwargs.setdefault("show", show) + return _plot3d_plot_contour_helper( + SurfaceOver2DRangeSeries, *args, **kwargs) + + +def plot3d_parametric_surface(*args, show=True, **kwargs): + """ + Plots a 3D parametric surface plot. + + Explanation + =========== + + Single plot. + + ``plot3d_parametric_surface(expr_x, expr_y, expr_z, range_u, range_v, **kwargs)`` + + If the ranges is not specified, then a default range of (-10, 10) is used. + + Multiple plots. + + ``plot3d_parametric_surface((expr_x, expr_y, expr_z, range_u, range_v), ..., **kwargs)`` + + Ranges have to be specified for every expression. + + Default range may change in the future if a more advanced default range + detection algorithm is implemented. + + Arguments + ========= + + expr_x : Expression representing the function along ``x``. + + expr_y : Expression representing the function along ``y``. + + expr_z : Expression representing the function along ``z``. + + range_u : (:class:`~.Symbol`, float, float) + A 3-tuple denoting the range of the u variable, e.g. (u, 0, 5). + + range_v : (:class:`~.Symbol`, float, float) + A 3-tuple denoting the range of the v variable, e.g. (v, 0, 5). + + Keyword Arguments + ================= + + Arguments for ``ParametricSurfaceSeries`` class: + + n1 : int + The ``u`` range is sampled uniformly at ``n1`` of points. + This keyword argument replaces ``nb_of_points_u``, which should be + considered deprecated. + + n2 : int + The ``v`` range is sampled uniformly at ``n2`` of points. + This keyword argument replaces ``nb_of_points_v``, which should be + considered deprecated. + + Aesthetics: + + surface_color : Function which returns a float + Specifies the color for the surface of the plot. See + :class:`~Plot` for more details. + + If there are multiple plots, then the same series arguments are applied for + all the plots. If you want to set these options separately, you can index + the returned ``Plot`` object and set it. + + + Arguments for ``Plot`` class: + + title : str + Title of the plot. + + size : (float, float), optional + A tuple in the form (width, height) in inches to specify the size of the + overall figure. The default value is set to ``None``, meaning the size will + be set by the default backend. + + Examples + ======== + + .. plot:: + :context: reset + :format: doctest + :include-source: True + + >>> from sympy import symbols, cos, sin + >>> from sympy.plotting import plot3d_parametric_surface + >>> u, v = symbols('u v') + + Single plot. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot3d_parametric_surface(cos(u + v), sin(u - v), u - v, + ... (u, -5, 5), (v, -5, 5)) + Plot object containing: + [0]: parametric cartesian surface: (cos(u + v), sin(u - v), u - v) for u over (-5.0, 5.0) and v over (-5.0, 5.0) + + + See Also + ======== + + Plot, ParametricSurfaceSeries + + """ + + args = _plot_sympify(args) + plot_expr = _check_arguments(args, 3, 2, **kwargs) + kwargs.setdefault("xlabel", "x") + kwargs.setdefault("ylabel", "y") + kwargs.setdefault("zlabel", "z") + + labels = kwargs.pop("label", []) + rendering_kw = kwargs.pop("rendering_kw", None) + series = _create_series(ParametricSurfaceSeries, plot_expr, **kwargs) + _set_labels(series, labels, rendering_kw) + + plots = plot_factory(*series, **kwargs) + if show: + plots.show() + return plots + +def plot_contour(*args, show=True, **kwargs): + """ + Draws contour plot of a function + + Usage + ===== + + Single plot + + ``plot_contour(expr, range_x, range_y, **kwargs)`` + + If the ranges are not specified, then a default range of (-10, 10) is used. + + Multiple plot with the same range. + + ``plot_contour(expr1, expr2, range_x, range_y, **kwargs)`` + + If the ranges are not specified, then a default range of (-10, 10) is used. + + Multiple plots with different ranges. + + ``plot_contour((expr1, range_x, range_y), (expr2, range_x, range_y), ..., **kwargs)`` + + Ranges have to be specified for every expression. + + Default range may change in the future if a more advanced default range + detection algorithm is implemented. + + Arguments + ========= + + expr : Expression representing the function along x. + + range_x : (:class:`Symbol`, float, float) + A 3-tuple denoting the range of the x variable, e.g. (x, 0, 5). + + range_y : (:class:`Symbol`, float, float) + A 3-tuple denoting the range of the y variable, e.g. (y, 0, 5). + + Keyword Arguments + ================= + + Arguments for ``ContourSeries`` class: + + n1 : int + The x range is sampled uniformly at ``n1`` of points. + This keyword argument replaces ``nb_of_points_x``, which should be + considered deprecated. + + n2 : int + The y range is sampled uniformly at ``n2`` of points. + This keyword argument replaces ``nb_of_points_y``, which should be + considered deprecated. + + Aesthetics: + + surface_color : Function which returns a float + Specifies the color for the surface of the plot. See + :class:`sympy.plotting.Plot` for more details. + + If there are multiple plots, then the same series arguments are applied to + all the plots. If you want to set these options separately, you can index + the returned ``Plot`` object and set it. + + Arguments for ``Plot`` class: + + title : str + Title of the plot. + + size : (float, float), optional + A tuple in the form (width, height) in inches to specify the size of + the overall figure. The default value is set to ``None``, meaning + the size will be set by the default backend. + + See Also + ======== + + Plot, ContourSeries + + """ + kwargs.setdefault("show", show) + return _plot3d_plot_contour_helper(ContourSeries, *args, **kwargs) + + +def check_arguments(args, expr_len, nb_of_free_symbols): + """ + Checks the arguments and converts into tuples of the + form (exprs, ranges). + + Examples + ======== + + .. plot:: + :context: reset + :format: doctest + :include-source: True + + >>> from sympy import cos, sin, symbols + >>> from sympy.plotting.plot import check_arguments + >>> x = symbols('x') + >>> check_arguments([cos(x), sin(x)], 2, 1) + [(cos(x), sin(x), (x, -10, 10))] + + >>> check_arguments([x, x**2], 1, 1) + [(x, (x, -10, 10)), (x**2, (x, -10, 10))] + """ + if not args: + return [] + if expr_len > 1 and isinstance(args[0], Expr): + # Multiple expressions same range. + # The arguments are tuples when the expression length is + # greater than 1. + if len(args) < expr_len: + raise ValueError("len(args) should not be less than expr_len") + for i in range(len(args)): + if isinstance(args[i], Tuple): + break + else: + i = len(args) + 1 + + exprs = Tuple(*args[:i]) + free_symbols = list(set().union(*[e.free_symbols for e in exprs])) + if len(args) == expr_len + nb_of_free_symbols: + #Ranges given + plots = [exprs + Tuple(*args[expr_len:])] + else: + default_range = Tuple(-10, 10) + ranges = [] + for symbol in free_symbols: + ranges.append(Tuple(symbol) + default_range) + + for i in range(len(free_symbols) - nb_of_free_symbols): + ranges.append(Tuple(Dummy()) + default_range) + plots = [exprs + Tuple(*ranges)] + return plots + + if isinstance(args[0], Expr) or (isinstance(args[0], Tuple) and + len(args[0]) == expr_len and + expr_len != 3): + # Cannot handle expressions with number of expression = 3. It is + # not possible to differentiate between expressions and ranges. + #Series of plots with same range + for i in range(len(args)): + if isinstance(args[i], Tuple) and len(args[i]) != expr_len: + break + if not isinstance(args[i], Tuple): + args[i] = Tuple(args[i]) + else: + i = len(args) + 1 + + exprs = args[:i] + assert all(isinstance(e, Expr) for expr in exprs for e in expr) + free_symbols = list(set().union(*[e.free_symbols for expr in exprs + for e in expr])) + + if len(free_symbols) > nb_of_free_symbols: + raise ValueError("The number of free_symbols in the expression " + "is greater than %d" % nb_of_free_symbols) + if len(args) == i + nb_of_free_symbols and isinstance(args[i], Tuple): + ranges = Tuple(*list(args[ + i:i + nb_of_free_symbols])) + plots = [expr + ranges for expr in exprs] + return plots + else: + # Use default ranges. + default_range = Tuple(-10, 10) + ranges = [] + for symbol in free_symbols: + ranges.append(Tuple(symbol) + default_range) + + for i in range(nb_of_free_symbols - len(free_symbols)): + ranges.append(Tuple(Dummy()) + default_range) + ranges = Tuple(*ranges) + plots = [expr + ranges for expr in exprs] + return plots + + elif isinstance(args[0], Tuple) and len(args[0]) == expr_len + nb_of_free_symbols: + # Multiple plots with different ranges. + for arg in args: + for i in range(expr_len): + if not isinstance(arg[i], Expr): + raise ValueError("Expected an expression, given %s" % + str(arg[i])) + for i in range(nb_of_free_symbols): + if not len(arg[i + expr_len]) == 3: + raise ValueError("The ranges should be a tuple of " + "length 3, got %s" % str(arg[i + expr_len])) + return args diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/plot_implicit.py b/.venv/lib/python3.13/site-packages/sympy/plotting/plot_implicit.py new file mode 100644 index 0000000000000000000000000000000000000000..5dceaf0699a2e6d3ff0bc30f415721918724cad5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/plot_implicit.py @@ -0,0 +1,233 @@ +"""Implicit plotting module for SymPy. + +Explanation +=========== + +The module implements a data series called ImplicitSeries which is used by +``Plot`` class to plot implicit plots for different backends. The module, +by default, implements plotting using interval arithmetic. It switches to a +fall back algorithm if the expression cannot be plotted using interval arithmetic. +It is also possible to specify to use the fall back algorithm for all plots. + +Boolean combinations of expressions cannot be plotted by the fall back +algorithm. + +See Also +======== + +sympy.plotting.plot + +References +========== + +.. [1] Jeffrey Allen Tupper. Reliable Two-Dimensional Graphing Methods for +Mathematical Formulae with Two Free Variables. + +.. [2] Jeffrey Allen Tupper. Graphing Equations with Generalized Interval +Arithmetic. Master's thesis. University of Toronto, 1996 + +""" + + +from sympy.core.containers import Tuple +from sympy.core.symbol import (Dummy, Symbol) +from sympy.polys.polyutils import _sort_gens +from sympy.plotting.series import ImplicitSeries, _set_discretization_points +from sympy.plotting.plot import plot_factory +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.iterables import flatten + + +__doctest_requires__ = {'plot_implicit': ['matplotlib']} + + +@doctest_depends_on(modules=('matplotlib',)) +def plot_implicit(expr, x_var=None, y_var=None, adaptive=True, depth=0, + n=300, line_color="blue", show=True, **kwargs): + """A plot function to plot implicit equations / inequalities. + + Arguments + ========= + + - expr : The equation / inequality that is to be plotted. + - x_var (optional) : symbol to plot on x-axis or tuple giving symbol + and range as ``(symbol, xmin, xmax)`` + - y_var (optional) : symbol to plot on y-axis or tuple giving symbol + and range as ``(symbol, ymin, ymax)`` + + If neither ``x_var`` nor ``y_var`` are given then the free symbols in the + expression will be assigned in the order they are sorted. + + The following keyword arguments can also be used: + + - ``adaptive`` Boolean. The default value is set to True. It has to be + set to False if you want to use a mesh grid. + + - ``depth`` integer. The depth of recursion for adaptive mesh grid. + Default value is 0. Takes value in the range (0, 4). + + - ``n`` integer. The number of points if adaptive mesh grid is not + used. Default value is 300. This keyword argument replaces ``points``, + which should be considered deprecated. + + - ``show`` Boolean. Default value is True. If set to False, the plot will + not be shown. See ``Plot`` for further information. + + - ``title`` string. The title for the plot. + + - ``xlabel`` string. The label for the x-axis + + - ``ylabel`` string. The label for the y-axis + + Aesthetics options: + + - ``line_color``: float or string. Specifies the color for the plot. + See ``Plot`` to see how to set color for the plots. + Default value is "Blue" + + plot_implicit, by default, uses interval arithmetic to plot functions. If + the expression cannot be plotted using interval arithmetic, it defaults to + a generating a contour using a mesh grid of fixed number of points. By + setting adaptive to False, you can force plot_implicit to use the mesh + grid. The mesh grid method can be effective when adaptive plotting using + interval arithmetic, fails to plot with small line width. + + Examples + ======== + + Plot expressions: + + .. plot:: + :context: reset + :format: doctest + :include-source: True + + >>> from sympy import plot_implicit, symbols, Eq, And + >>> x, y = symbols('x y') + + Without any ranges for the symbols in the expression: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p1 = plot_implicit(Eq(x**2 + y**2, 5)) + + With the range for the symbols: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p2 = plot_implicit( + ... Eq(x**2 + y**2, 3), (x, -3, 3), (y, -3, 3)) + + With depth of recursion as argument: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p3 = plot_implicit( + ... Eq(x**2 + y**2, 5), (x, -4, 4), (y, -4, 4), depth = 2) + + Using mesh grid and not using adaptive meshing: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p4 = plot_implicit( + ... Eq(x**2 + y**2, 5), (x, -5, 5), (y, -2, 2), + ... adaptive=False) + + Using mesh grid without using adaptive meshing with number of points + specified: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p5 = plot_implicit( + ... Eq(x**2 + y**2, 5), (x, -5, 5), (y, -2, 2), + ... adaptive=False, n=400) + + Plotting regions: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p6 = plot_implicit(y > x**2) + + Plotting Using boolean conjunctions: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p7 = plot_implicit(And(y > x, y > -x)) + + When plotting an expression with a single variable (y - 1, for example), + specify the x or the y variable explicitly: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p8 = plot_implicit(y - 1, y_var=y) + >>> p9 = plot_implicit(x - 1, x_var=x) + """ + + xyvar = [i for i in (x_var, y_var) if i is not None] + free_symbols = expr.free_symbols + range_symbols = Tuple(*flatten(xyvar)).free_symbols + undeclared = free_symbols - range_symbols + if len(free_symbols & range_symbols) > 2: + raise NotImplementedError("Implicit plotting is not implemented for " + "more than 2 variables") + + #Create default ranges if the range is not provided. + default_range = Tuple(-5, 5) + def _range_tuple(s): + if isinstance(s, Symbol): + return Tuple(s) + default_range + if len(s) == 3: + return Tuple(*s) + raise ValueError('symbol or `(symbol, min, max)` expected but got %s' % s) + + if len(xyvar) == 0: + xyvar = list(_sort_gens(free_symbols)) + var_start_end_x = _range_tuple(xyvar[0]) + x = var_start_end_x[0] + if len(xyvar) != 2: + if x in undeclared or not undeclared: + xyvar.append(Dummy('f(%s)' % x.name)) + else: + xyvar.append(undeclared.pop()) + var_start_end_y = _range_tuple(xyvar[1]) + + kwargs = _set_discretization_points(kwargs, ImplicitSeries) + series_argument = ImplicitSeries( + expr, var_start_end_x, var_start_end_y, + adaptive=adaptive, depth=depth, + n=n, line_color=line_color) + + #set the x and y limits + kwargs['xlim'] = tuple(float(x) for x in var_start_end_x[1:]) + kwargs['ylim'] = tuple(float(y) for y in var_start_end_y[1:]) + # set the x and y labels + kwargs.setdefault('xlabel', var_start_end_x[0]) + kwargs.setdefault('ylabel', var_start_end_y[0]) + p = plot_factory(series_argument, **kwargs) + if show: + p.show() + return p diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/plotgrid.py b/.venv/lib/python3.13/site-packages/sympy/plotting/plotgrid.py new file mode 100644 index 0000000000000000000000000000000000000000..8ff811c591e762275df1a0e3a221d05920d1804e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/plotgrid.py @@ -0,0 +1,188 @@ + +from sympy.external import import_module +import sympy.plotting.backends.base_backend as base_backend + + +# N.B. +# When changing the minimum module version for matplotlib, please change +# the same in the `SymPyDocTestFinder`` in `sympy/testing/runtests.py` + + +__doctest_requires__ = { + ("PlotGrid",): ["matplotlib"], +} + + +class PlotGrid: + """This class helps to plot subplots from already created SymPy plots + in a single figure. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.plotting import plot, plot3d, PlotGrid + >>> x, y = symbols('x, y') + >>> p1 = plot(x, x**2, x**3, (x, -5, 5)) + >>> p2 = plot((x**2, (x, -6, 6)), (x, (x, -5, 5))) + >>> p3 = plot(x**3, (x, -5, 5)) + >>> p4 = plot3d(x*y, (x, -5, 5), (y, -5, 5)) + + Plotting vertically in a single line: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> PlotGrid(2, 1, p1, p2) + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: x for x over (-5.0, 5.0) + [1]: cartesian line: x**2 for x over (-5.0, 5.0) + [2]: cartesian line: x**3 for x over (-5.0, 5.0) + Plot[1]:Plot object containing: + [0]: cartesian line: x**2 for x over (-6.0, 6.0) + [1]: cartesian line: x for x over (-5.0, 5.0) + + Plotting horizontally in a single line: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> PlotGrid(1, 3, p2, p3, p4) + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: x**2 for x over (-6.0, 6.0) + [1]: cartesian line: x for x over (-5.0, 5.0) + Plot[1]:Plot object containing: + [0]: cartesian line: x**3 for x over (-5.0, 5.0) + Plot[2]:Plot object containing: + [0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0) + + Plotting in a grid form: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> PlotGrid(2, 2, p1, p2, p3, p4) + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: x for x over (-5.0, 5.0) + [1]: cartesian line: x**2 for x over (-5.0, 5.0) + [2]: cartesian line: x**3 for x over (-5.0, 5.0) + Plot[1]:Plot object containing: + [0]: cartesian line: x**2 for x over (-6.0, 6.0) + [1]: cartesian line: x for x over (-5.0, 5.0) + Plot[2]:Plot object containing: + [0]: cartesian line: x**3 for x over (-5.0, 5.0) + Plot[3]:Plot object containing: + [0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0) + + """ + def __init__(self, nrows, ncolumns, *args, show=True, size=None, **kwargs): + """ + Parameters + ========== + + nrows : + The number of rows that should be in the grid of the + required subplot. + ncolumns : + The number of columns that should be in the grid + of the required subplot. + + nrows and ncolumns together define the required grid. + + Arguments + ========= + + A list of predefined plot objects entered in a row-wise sequence + i.e. plot objects which are to be in the top row of the required + grid are written first, then the second row objects and so on + + Keyword arguments + ================= + + show : Boolean + The default value is set to ``True``. Set show to ``False`` and + the function will not display the subplot. The returned instance + of the ``PlotGrid`` class can then be used to save or display the + plot by calling the ``save()`` and ``show()`` methods + respectively. + size : (float, float), optional + A tuple in the form (width, height) in inches to specify the size of + the overall figure. The default value is set to ``None``, meaning + the size will be set by the default backend. + """ + self.matplotlib = import_module('matplotlib', + import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']}, + min_module_version='1.1.0', catch=(RuntimeError,)) + self.nrows = nrows + self.ncolumns = ncolumns + self._series = [] + self._fig = None + self.args = args + for arg in args: + self._series.append(arg._series) + self.size = size + if show and self.matplotlib: + self.show() + + def _create_figure(self): + gs = self.matplotlib.gridspec.GridSpec(self.nrows, self.ncolumns) + mapping = {} + c = 0 + for i in range(self.nrows): + for j in range(self.ncolumns): + if c < len(self.args): + mapping[gs[i, j]] = self.args[c] + c += 1 + + kw = {} if not self.size else {"figsize": self.size} + self._fig = self.matplotlib.pyplot.figure(**kw) + for spec, p in mapping.items(): + kw = ({"projection": "3d"} if (len(p._series) > 0 and + p._series[0].is_3D) else {}) + cur_ax = self._fig.add_subplot(spec, **kw) + p._plotgrid_fig = self._fig + p._plotgrid_ax = cur_ax + p.process_series() + + @property + def fig(self): + if not self._fig: + self._create_figure() + return self._fig + + @property + def _backend(self): + return self + + def close(self): + self.matplotlib.pyplot.close(self.fig) + + def show(self): + if base_backend._show: + self.fig.tight_layout() + self.matplotlib.pyplot.show() + else: + self.close() + + def save(self, path): + self.fig.savefig(path) + + def __str__(self): + plot_strs = [('Plot[%d]:' % i) + str(plot) + for i, plot in enumerate(self.args)] + + return 'PlotGrid object containing:\n' + '\n'.join(plot_strs) diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/series.py b/.venv/lib/python3.13/site-packages/sympy/plotting/series.py new file mode 100644 index 0000000000000000000000000000000000000000..ddd64116277668389fb8defc8289543667d2c9e8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/series.py @@ -0,0 +1,2591 @@ +### The base class for all series +from collections.abc import Callable +from sympy.calculus.util import continuous_domain +from sympy.concrete import Sum, Product +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import arity +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Symbol +from sympy.functions import atan2, zeta, frac, ceiling, floor, im +from sympy.core.relational import (Equality, GreaterThan, + LessThan, Relational, Ne) +from sympy.core.sympify import sympify +from sympy.external import import_module +from sympy.logic.boolalg import BooleanFunction +from sympy.plotting.utils import _get_free_symbols, extract_solution +from sympy.printing.latex import latex +from sympy.printing.pycode import PythonCodePrinter +from sympy.printing.precedence import precedence +from sympy.sets.sets import Set, Interval, Union +from sympy.simplify.simplify import nsimplify +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.lambdify import lambdify +from .intervalmath import interval +import warnings + + +class IntervalMathPrinter(PythonCodePrinter): + """A printer to be used inside `plot_implicit` when `adaptive=True`, + in which case the interval arithmetic module is going to be used, which + requires the following edits. + """ + def _print_And(self, expr): + PREC = precedence(expr) + return " & ".join(self.parenthesize(a, PREC) + for a in sorted(expr.args, key=default_sort_key)) + + def _print_Or(self, expr): + PREC = precedence(expr) + return " | ".join(self.parenthesize(a, PREC) + for a in sorted(expr.args, key=default_sort_key)) + + +def _uniform_eval(f1, f2, *args, modules=None, + force_real_eval=False, has_sum=False): + """ + Note: this is an experimental function, as such it is prone to changes. + Please, do not use it in your code. + """ + np = import_module('numpy') + + def wrapper_func(func, *args): + try: + return complex(func(*args)) + except (ZeroDivisionError, OverflowError): + return complex(np.nan, np.nan) + + # NOTE: np.vectorize is much slower than numpy vectorized operations. + # However, this modules must be able to evaluate functions also with + # mpmath or sympy. + wrapper_func = np.vectorize(wrapper_func, otypes=[complex]) + + def _eval_with_sympy(err=None): + if f2 is None: + msg = "Impossible to evaluate the provided numerical function" + if err is None: + msg += "." + else: + msg += "because the following exception was raised:\n" + "{}: {}".format(type(err).__name__, err) + raise RuntimeError(msg) + if err: + warnings.warn( + "The evaluation with %s failed.\n" % ( + "NumPy/SciPy" if not modules else modules) + + "{}: {}\n".format(type(err).__name__, err) + + "Trying to evaluate the expression with Sympy, but it might " + "be a slow operation." + ) + return wrapper_func(f2, *args) + + if modules == "sympy": + return _eval_with_sympy() + + try: + return wrapper_func(f1, *args) + except Exception as err: + return _eval_with_sympy(err) + + +def _adaptive_eval(f, x): + """Evaluate f(x) with an adaptive algorithm. Post-process the result. + If a symbolic expression is evaluated with SymPy, it might returns + another symbolic expression, containing additions, ... + Force evaluation to a float. + + Parameters + ========== + f : callable + x : float + """ + np = import_module('numpy') + + y = f(x) + if isinstance(y, Expr) and (not y.is_Number): + y = y.evalf() + y = complex(y) + if y.imag > 1e-08: + return np.nan + return y.real + + +def _get_wrapper_for_expr(ret): + wrapper = "%s" + if ret == "real": + wrapper = "re(%s)" + elif ret == "imag": + wrapper = "im(%s)" + elif ret == "abs": + wrapper = "abs(%s)" + elif ret == "arg": + wrapper = "arg(%s)" + return wrapper + + +class BaseSeries: + """Base class for the data objects containing stuff to be plotted. + + Notes + ===== + + The backend should check if it supports the data series that is given. + (e.g. TextBackend supports only LineOver1DRangeSeries). + It is the backend responsibility to know how to use the class of + data series that is given. + + Some data series classes are grouped (using a class attribute like is_2Dline) + according to the api they present (based only on convention). The backend is + not obliged to use that api (e.g. LineOver1DRangeSeries belongs to the + is_2Dline group and presents the get_points method, but the + TextBackend does not use the get_points method). + + BaseSeries + """ + + # Some flags follow. The rationale for using flags instead of checking base + # classes is that setting multiple flags is simpler than multiple + # inheritance. + + is_2Dline = False + # Some of the backends expect: + # - get_points returning 1D np.arrays list_x, list_y + # - get_color_array returning 1D np.array (done in Line2DBaseSeries) + # with the colors calculated at the points from get_points + + is_3Dline = False + # Some of the backends expect: + # - get_points returning 1D np.arrays list_x, list_y, list_y + # - get_color_array returning 1D np.array (done in Line2DBaseSeries) + # with the colors calculated at the points from get_points + + is_3Dsurface = False + # Some of the backends expect: + # - get_meshes returning mesh_x, mesh_y, mesh_z (2D np.arrays) + # - get_points an alias for get_meshes + + is_contour = False + # Some of the backends expect: + # - get_meshes returning mesh_x, mesh_y, mesh_z (2D np.arrays) + # - get_points an alias for get_meshes + + is_implicit = False + # Some of the backends expect: + # - get_meshes returning mesh_x (1D array), mesh_y(1D array, + # mesh_z (2D np.arrays) + # - get_points an alias for get_meshes + # Different from is_contour as the colormap in backend will be + # different + + is_interactive = False + # An interactive series can update its data. + + is_parametric = False + # The calculation of aesthetics expects: + # - get_parameter_points returning one or two np.arrays (1D or 2D) + # used for calculation aesthetics + + is_generic = False + # Represent generic user-provided numerical data + + is_vector = False + is_2Dvector = False + is_3Dvector = False + # Represents a 2D or 3D vector data series + + _N = 100 + # default number of discretization points for uniform sampling. Each + # subclass can set its number. + + def __init__(self, *args, **kwargs): + kwargs = _set_discretization_points(kwargs.copy(), type(self)) + # discretize the domain using only integer numbers + self.only_integers = kwargs.get("only_integers", False) + # represents the evaluation modules to be used by lambdify + self.modules = kwargs.get("modules", None) + # plot functions might create data series that might not be useful to + # be shown on the legend, for example wireframe lines on 3D plots. + self.show_in_legend = kwargs.get("show_in_legend", True) + # line and surface series can show data with a colormap, hence a + # colorbar is essential to understand the data. However, sometime it + # is useful to hide it on series-by-series base. The following keyword + # controls whether the series should show a colorbar or not. + self.colorbar = kwargs.get("colorbar", True) + # Some series might use a colormap as default coloring. Setting this + # attribute to False will inform the backends to use solid color. + self.use_cm = kwargs.get("use_cm", False) + # If True, the backend will attempt to render it on a polar-projection + # axis, or using a polar discretization if a 3D plot is requested + self.is_polar = kwargs.get("is_polar", kwargs.get("polar", False)) + # If True, the rendering will use points, not lines. + self.is_point = kwargs.get("is_point", kwargs.get("point", False)) + # some backend is able to render latex, other needs standard text + self._label = self._latex_label = "" + + self._ranges = [] + self._n = [ + int(kwargs.get("n1", self._N)), + int(kwargs.get("n2", self._N)), + int(kwargs.get("n3", self._N)) + ] + self._scales = [ + kwargs.get("xscale", "linear"), + kwargs.get("yscale", "linear"), + kwargs.get("zscale", "linear") + ] + + # enable interactive widget plots + self._params = kwargs.get("params", {}) + if not isinstance(self._params, dict): + raise TypeError("`params` must be a dictionary mapping symbols " + "to numeric values.") + if len(self._params) > 0: + self.is_interactive = True + + # contains keyword arguments that will be passed to the rendering + # function of the chosen plotting library + self.rendering_kw = kwargs.get("rendering_kw", {}) + + # numerical transformation functions to be applied to the output data: + # x, y, z (coordinates), p (parameter on parametric plots) + self._tx = kwargs.get("tx", None) + self._ty = kwargs.get("ty", None) + self._tz = kwargs.get("tz", None) + self._tp = kwargs.get("tp", None) + if not all(callable(t) or (t is None) for t in + [self._tx, self._ty, self._tz, self._tp]): + raise TypeError("`tx`, `ty`, `tz`, `tp` must be functions.") + + # list of numerical functions representing the expressions to evaluate + self._functions = [] + # signature for the numerical functions + self._signature = [] + # some expressions don't like to be evaluated over complex data. + # if that's the case, set this to True + self._force_real_eval = kwargs.get("force_real_eval", None) + # this attribute will eventually contain a dictionary with the + # discretized ranges + self._discretized_domain = None + # whether the series contains any interactive range, which is a range + # where the minimum and maximum values can be changed with an + # interactive widget + self._interactive_ranges = False + # NOTE: consider a generic summation, for example: + # s = Sum(cos(pi * x), (x, 1, y)) + # This gets lambdified to something: + # sum(cos(pi*x) for x in range(1, y+1)) + # Hence, y needs to be an integer, otherwise it raises: + # TypeError: 'complex' object cannot be interpreted as an integer + # This list will contains symbols that are upper bound to summations + # or products + self._needs_to_be_int = [] + # a color function will be responsible to set the line/surface color + # according to some logic. Each data series will et an appropriate + # default value. + self.color_func = None + # NOTE: color_func usually receives numerical functions that are going + # to be evaluated over the coordinates of the computed points (or the + # discretized meshes). + # However, if an expression is given to color_func, then it will be + # lambdified with symbols in self._signature, and it will be evaluated + # with the same data used to evaluate the plotted expression. + self._eval_color_func_with_signature = False + + def _block_lambda_functions(self, *exprs): + """Some data series can be used to plot numerical functions, others + cannot. Execute this method inside the `__init__` to prevent the + processing of numerical functions. + """ + if any(callable(e) for e in exprs): + raise TypeError(type(self).__name__ + " requires a symbolic " + "expression.") + + def _check_fs(self): + """ Checks if there are enough parameters and free symbols. + """ + exprs, ranges = self.expr, self.ranges + params, label = self.params, self.label + exprs = exprs if hasattr(exprs, "__iter__") else [exprs] + if any(callable(e) for e in exprs): + return + + # from the expression's free symbols, remove the ones used in + # the parameters and the ranges + fs = _get_free_symbols(exprs) + fs = fs.difference(params.keys()) + if ranges is not None: + fs = fs.difference([r[0] for r in ranges]) + + if len(fs) > 0: + raise ValueError( + "Incompatible expression and parameters.\n" + + "Expression: {}\n".format( + (exprs, ranges, label) if ranges is not None else (exprs, label)) + + "params: {}\n".format(params) + + "Specify what these symbols represent: {}\n".format(fs) + + "Are they ranges or parameters?" + ) + + # verify that all symbols are known (they either represent plotting + # ranges or parameters) + range_symbols = [r[0] for r in ranges] + for r in ranges: + fs = set().union(*[e.free_symbols for e in r[1:]]) + if any(t in fs for t in range_symbols): + # ranges can't depend on each other, for example this are + # not allowed: + # (x, 0, y), (y, 0, 3) + # (x, 0, y), (y, x + 2, 3) + raise ValueError("Range symbols can't be included into " + "minimum and maximum of a range. " + "Received range: %s" % str(r)) + if len(fs) > 0: + self._interactive_ranges = True + remaining_fs = fs.difference(params.keys()) + if len(remaining_fs) > 0: + raise ValueError( + "Unknown symbols found in plotting range: %s. " % (r,) + + "Are the following parameters? %s" % remaining_fs) + + def _create_lambda_func(self): + """Create the lambda functions to be used by the uniform meshing + strategy. + + Notes + ===== + The old sympy.plotting used experimental_lambdify. It created one + lambda function each time an evaluation was requested. If that failed, + it went on to create a different lambda function and evaluated it, + and so on. + + This new module changes strategy: it creates right away the default + lambda function as well as the backup one. The reason is that the + series could be interactive, hence the numerical function will be + evaluated multiple times. So, let's create the functions just once. + + This approach works fine for the majority of cases, in which the + symbolic expression is relatively short, hence the lambdification + is fast. If the expression is very long, this approach takes twice + the time to create the lambda functions. Be aware of that! + """ + exprs = self.expr if hasattr(self.expr, "__iter__") else [self.expr] + if not any(callable(e) for e in exprs): + fs = _get_free_symbols(exprs) + self._signature = sorted(fs, key=lambda t: t.name) + + # Generate a list of lambda functions, two for each expression: + # 1. the default one. + # 2. the backup one, in case of failures with the default one. + self._functions = [] + for e in exprs: + # TODO: set cse=True once this issue is solved: + # https://github.com/sympy/sympy/issues/24246 + self._functions.append([ + lambdify(self._signature, e, modules=self.modules), + lambdify(self._signature, e, modules="sympy", dummify=True), + ]) + else: + self._signature = sorted([r[0] for r in self.ranges], key=lambda t: t.name) + self._functions = [(e, None) for e in exprs] + + # deal with symbolic color_func + if isinstance(self.color_func, Expr): + self.color_func = lambdify(self._signature, self.color_func) + self._eval_color_func_with_signature = True + + def _update_range_value(self, t): + """If the value of a plotting range is a symbolic expression, + substitute the parameters in order to get a numerical value. + """ + if not self._interactive_ranges: + return complex(t) + return complex(t.subs(self.params)) + + def _create_discretized_domain(self): + """Discretize the ranges for uniform meshing strategy. + """ + # NOTE: the goal is to create a dictionary stored in + # self._discretized_domain, mapping symbols to a numpy array + # representing the discretization + discr_symbols = [] + discretizations = [] + + # create a 1D discretization + for i, r in enumerate(self.ranges): + discr_symbols.append(r[0]) + c_start = self._update_range_value(r[1]) + c_end = self._update_range_value(r[2]) + start = c_start.real if c_start.imag == c_end.imag == 0 else c_start + end = c_end.real if c_start.imag == c_end.imag == 0 else c_end + needs_integer_discr = self.only_integers or (r[0] in self._needs_to_be_int) + d = BaseSeries._discretize(start, end, self.n[i], + scale=self.scales[i], + only_integers=needs_integer_discr) + + if ((not self._force_real_eval) and (not needs_integer_discr) and + (d.dtype != "complex")): + d = d + 1j * c_start.imag + + if needs_integer_discr: + d = d.astype(int) + + discretizations.append(d) + + # create 2D or 3D + self._create_discretized_domain_helper(discr_symbols, discretizations) + + def _create_discretized_domain_helper(self, discr_symbols, discretizations): + """Create 2D or 3D discretized grids. + + Subclasses should override this method in order to implement a + different behaviour. + """ + np = import_module('numpy') + + # discretization suitable for 2D line plots, 3D surface plots, + # contours plots, vector plots + # NOTE: why indexing='ij'? Because it produces consistent results with + # np.mgrid. This is important as Mayavi requires this indexing + # to correctly compute 3D streamlines. While VTK is able to compute + # streamlines regardless of the indexing, with indexing='xy' it + # produces "strange" results with "voids" into the + # discretization volume. indexing='ij' solves the problem. + # Also note that matplotlib 2D streamlines requires indexing='xy'. + indexing = "xy" + if self.is_3Dvector or (self.is_3Dsurface and self.is_implicit): + indexing = "ij" + meshes = np.meshgrid(*discretizations, indexing=indexing) + self._discretized_domain = dict(zip(discr_symbols, meshes)) + + def _evaluate(self, cast_to_real=True): + """Evaluation of the symbolic expression (or expressions) with the + uniform meshing strategy, based on current values of the parameters. + """ + np = import_module('numpy') + + # create lambda functions + if not self._functions: + self._create_lambda_func() + # create (or update) the discretized domain + if (not self._discretized_domain) or self._interactive_ranges: + self._create_discretized_domain() + # ensure that discretized domains are returned with the proper order + discr = [self._discretized_domain[s[0]] for s in self.ranges] + + args = self._aggregate_args() + + results = [] + for f in self._functions: + r = _uniform_eval(*f, *args) + # the evaluation might produce an int/float. Need this correction. + r = self._correct_shape(np.array(r), discr[0]) + # sometime the evaluation is performed over arrays of type object. + # hence, `result` might be of type object, which don't work well + # with numpy real and imag functions. + r = r.astype(complex) + results.append(r) + + if cast_to_real: + discr = [np.real(d.astype(complex)) for d in discr] + return [*discr, *results] + + def _aggregate_args(self): + """Create a list of arguments to be passed to the lambda function, + sorted according to self._signature. + """ + args = [] + for s in self._signature: + if s in self._params.keys(): + args.append( + int(self._params[s]) if s in self._needs_to_be_int else + self._params[s] if self._force_real_eval + else complex(self._params[s])) + else: + args.append(self._discretized_domain[s]) + return args + + @property + def expr(self): + """Return the expression (or expressions) of the series.""" + return self._expr + + @expr.setter + def expr(self, e): + """Set the expression (or expressions) of the series.""" + is_iter = hasattr(e, "__iter__") + is_callable = callable(e) if not is_iter else any(callable(t) for t in e) + if is_callable: + self._expr = e + else: + self._expr = sympify(e) if not is_iter else Tuple(*e) + + # look for the upper bound of summations and products + s = set() + for e in self._expr.atoms(Sum, Product): + for a in e.args[1:]: + if isinstance(a[-1], Symbol): + s.add(a[-1]) + self._needs_to_be_int = list(s) + + # list of sympy functions that when lambdified, the corresponding + # numpy functions don't like complex-type arguments + pf = [ceiling, floor, atan2, frac, zeta] + if self._force_real_eval is not True: + check_res = [self._expr.has(f) for f in pf] + self._force_real_eval = any(check_res) + if self._force_real_eval and ((self.modules is None) or + (isinstance(self.modules, str) and "numpy" in self.modules)): + funcs = [f for f, c in zip(pf, check_res) if c] + warnings.warn("NumPy is unable to evaluate with complex " + "numbers some of the functions included in this " + "symbolic expression: %s. " % funcs + + "Hence, the evaluation will use real numbers. " + "If you believe the resulting plot is incorrect, " + "change the evaluation module by setting the " + "`modules` keyword argument.") + if self._functions: + # update lambda functions + self._create_lambda_func() + + @property + def is_3D(self): + flags3D = [self.is_3Dline, self.is_3Dsurface, self.is_3Dvector] + return any(flags3D) + + @property + def is_line(self): + flagslines = [self.is_2Dline, self.is_3Dline] + return any(flagslines) + + def _line_surface_color(self, prop, val): + """This method enables back-compatibility with old sympy.plotting""" + # NOTE: color_func is set inside the init method of the series. + # If line_color/surface_color is not a callable, then color_func will + # be set to None. + setattr(self, prop, val) + if callable(val) or isinstance(val, Expr): + self.color_func = val + setattr(self, prop, None) + elif val is not None: + self.color_func = None + + @property + def line_color(self): + return self._line_color + + @line_color.setter + def line_color(self, val): + self._line_surface_color("_line_color", val) + + @property + def n(self): + """Returns a list [n1, n2, n3] of numbers of discratization points. + """ + return self._n + + @n.setter + def n(self, v): + """Set the numbers of discretization points. ``v`` must be an int or + a list. + + Let ``s`` be a series. Then: + + * to set the number of discretization points along the x direction (or + first parameter): ``s.n = 10`` + * to set the number of discretization points along the x and y + directions (or first and second parameters): ``s.n = [10, 15]`` + * to set the number of discretization points along the x, y and z + directions: ``s.n = [10, 15, 20]`` + + The following is highly unreccomended, because it prevents + the execution of necessary code in order to keep updated data: + ``s.n[1] = 15`` + """ + if not hasattr(v, "__iter__"): + self._n[0] = v + else: + self._n[:len(v)] = v + if self._discretized_domain: + # update the discretized domain + self._create_discretized_domain() + + @property + def params(self): + """Get or set the current parameters dictionary. + + Parameters + ========== + + p : dict + + * key: symbol associated to the parameter + * val: the numeric value + """ + return self._params + + @params.setter + def params(self, p): + self._params = p + + def _post_init(self): + exprs = self.expr if hasattr(self.expr, "__iter__") else [self.expr] + if any(callable(e) for e in exprs) and self.params: + raise TypeError("`params` was provided, hence an interactive plot " + "is expected. However, interactive plots do not support " + "user-provided numerical functions.") + + # if the expressions is a lambda function and no label has been + # provided, then its better to do the following in order to avoid + # surprises on the backend + if any(callable(e) for e in exprs): + if self._label == str(self.expr): + self.label = "" + + self._check_fs() + + if hasattr(self, "adaptive") and self.adaptive and self.params: + warnings.warn("`params` was provided, hence an interactive plot " + "is expected. However, interactive plots do not support " + "adaptive evaluation. Automatically switched to " + "adaptive=False.") + self.adaptive = False + + @property + def scales(self): + return self._scales + + @scales.setter + def scales(self, v): + if isinstance(v, str): + self._scales[0] = v + else: + self._scales[:len(v)] = v + + @property + def surface_color(self): + return self._surface_color + + @surface_color.setter + def surface_color(self, val): + self._line_surface_color("_surface_color", val) + + @property + def rendering_kw(self): + return self._rendering_kw + + @rendering_kw.setter + def rendering_kw(self, kwargs): + if isinstance(kwargs, dict): + self._rendering_kw = kwargs + else: + self._rendering_kw = {} + if kwargs is not None: + warnings.warn( + "`rendering_kw` must be a dictionary, instead an " + "object of type %s was received. " % type(kwargs) + + "Automatically setting `rendering_kw` to an empty " + "dictionary") + + @staticmethod + def _discretize(start, end, N, scale="linear", only_integers=False): + """Discretize a 1D domain. + + Returns + ======= + + domain : np.ndarray with dtype=float or complex + The domain's dtype will be float or complex (depending on the + type of start/end) even if only_integers=True. It is left for + the downstream code to perform further casting, if necessary. + """ + np = import_module('numpy') + + if only_integers is True: + start, end = int(start), int(end) + N = end - start + 1 + + if scale == "linear": + return np.linspace(start, end, N) + return np.geomspace(start, end, N) + + @staticmethod + def _correct_shape(a, b): + """Convert ``a`` to a np.ndarray of the same shape of ``b``. + + Parameters + ========== + + a : int, float, complex, np.ndarray + Usually, this is the result of a numerical evaluation of a + symbolic expression. Even if a discretized domain was used to + evaluate the function, the result can be a scalar (int, float, + complex). Think for example to ``expr = Float(2)`` and + ``f = lambdify(x, expr)``. No matter the shape of the numerical + array representing x, the result of the evaluation will be + a single value. + + b : np.ndarray + It represents the correct shape that ``a`` should have. + + Returns + ======= + new_a : np.ndarray + An array with the correct shape. + """ + np = import_module('numpy') + + if not isinstance(a, np.ndarray): + a = np.array(a) + if a.shape != b.shape: + if a.shape == (): + a = a * np.ones_like(b) + else: + a = a.reshape(b.shape) + return a + + def eval_color_func(self, *args): + """Evaluate the color function. + + Parameters + ========== + + args : tuple + Arguments to be passed to the coloring function. Can be coordinates + or parameters or both. + + Notes + ===== + + The backend will request the data series to generate the numerical + data. Depending on the data series, either the data series itself or + the backend will eventually execute this function to generate the + appropriate coloring value. + """ + np = import_module('numpy') + if self.color_func is None: + # NOTE: with the line_color and surface_color attributes + # (back-compatibility with the old sympy.plotting module) it is + # possible to create a plot with a callable line_color (or + # surface_color). For example: + # p = plot(sin(x), line_color=lambda x, y: -y) + # This creates a ColoredLineOver1DRangeSeries with line_color=None + # and color_func=lambda x, y: -y, which effectively is a + # parametric series. Later we could change it to a string value: + # p[0].line_color = "red" + # However, this sets ine_color="red" and color_func=None, but the + # series is still ColoredLineOver1DRangeSeries (a parametric + # series), which will render using a color_func... + warnings.warn("This is likely not the result you were " + "looking for. Please, re-execute the plot command, this time " + "with the appropriate an appropriate value to line_color " + "or surface_color.") + return np.ones_like(args[0]) + + if self._eval_color_func_with_signature: + args = self._aggregate_args() + color = self.color_func(*args) + _re, _im = np.real(color), np.imag(color) + _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan + return _re + + nargs = arity(self.color_func) + if nargs == 1: + if self.is_2Dline and self.is_parametric: + if len(args) == 2: + # ColoredLineOver1DRangeSeries + return self._correct_shape(self.color_func(args[0]), args[0]) + # Parametric2DLineSeries + return self._correct_shape(self.color_func(args[2]), args[2]) + elif self.is_3Dline and self.is_parametric: + return self._correct_shape(self.color_func(args[3]), args[3]) + elif self.is_3Dsurface and self.is_parametric: + return self._correct_shape(self.color_func(args[3]), args[3]) + return self._correct_shape(self.color_func(args[0]), args[0]) + elif nargs == 2: + if self.is_3Dsurface and self.is_parametric: + return self._correct_shape(self.color_func(*args[3:]), args[3]) + return self._correct_shape(self.color_func(*args[:2]), args[0]) + return self._correct_shape(self.color_func(*args[:nargs]), args[0]) + + def get_data(self): + """Compute and returns the numerical data. + + The number of parameters returned by this method depends on the + specific instance. If ``s`` is the series, make sure to read + ``help(s.get_data)`` to understand what it returns. + """ + raise NotImplementedError + + def _get_wrapped_label(self, label, wrapper): + """Given a latex representation of an expression, wrap it inside + some characters. Matplotlib needs "$%s%$", K3D-Jupyter needs "%s". + """ + return wrapper % label + + def get_label(self, use_latex=False, wrapper="$%s$"): + """Return the label to be used to display the expression. + + Parameters + ========== + use_latex : bool + If False, the string representation of the expression is returned. + If True, the latex representation is returned. + wrapper : str + The backend might need the latex representation to be wrapped by + some characters. Default to ``"$%s$"``. + + Returns + ======= + label : str + """ + if use_latex is False: + return self._label + if self._label == str(self.expr): + # when the backend requests a latex label and user didn't provide + # any label + return self._get_wrapped_label(self._latex_label, wrapper) + return self._latex_label + + @property + def label(self): + return self.get_label() + + @label.setter + def label(self, val): + """Set the labels associated to this series.""" + # NOTE: the init method of any series requires a label. If the user do + # not provide it, the preprocessing function will set label=None, which + # informs the series to initialize two attributes: + # _label contains the string representation of the expression. + # _latex_label contains the latex representation of the expression. + self._label = self._latex_label = val + + @property + def ranges(self): + return self._ranges + + @ranges.setter + def ranges(self, val): + new_vals = [] + for v in val: + if v is not None: + new_vals.append(tuple([sympify(t) for t in v])) + self._ranges = new_vals + + def _apply_transform(self, *args): + """Apply transformations to the results of numerical evaluation. + + Parameters + ========== + args : tuple + Results of numerical evaluation. + + Returns + ======= + transformed_args : tuple + Tuple containing the transformed results. + """ + t = lambda x, transform: x if transform is None else transform(x) + x, y, z = None, None, None + if len(args) == 2: + x, y = args + return t(x, self._tx), t(y, self._ty) + elif (len(args) == 3) and isinstance(self, Parametric2DLineSeries): + x, y, u = args + return (t(x, self._tx), t(y, self._ty), t(u, self._tp)) + elif len(args) == 3: + x, y, z = args + return t(x, self._tx), t(y, self._ty), t(z, self._tz) + elif (len(args) == 4) and isinstance(self, Parametric3DLineSeries): + x, y, z, u = args + return (t(x, self._tx), t(y, self._ty), t(z, self._tz), t(u, self._tp)) + elif len(args) == 4: # 2D vector plot + x, y, u, v = args + return ( + t(x, self._tx), t(y, self._ty), + t(u, self._tx), t(v, self._ty) + ) + elif (len(args) == 5) and isinstance(self, ParametricSurfaceSeries): + x, y, z, u, v = args + return (t(x, self._tx), t(y, self._ty), t(z, self._tz), u, v) + elif (len(args) == 6) and self.is_3Dvector: # 3D vector plot + x, y, z, u, v, w = args + return ( + t(x, self._tx), t(y, self._ty), t(z, self._tz), + t(u, self._tx), t(v, self._ty), t(w, self._tz) + ) + elif len(args) == 6: # complex plot + x, y, _abs, _arg, img, colors = args + return ( + x, y, t(_abs, self._tz), _arg, img, colors) + return args + + def _str_helper(self, s): + pre, post = "", "" + if self.is_interactive: + pre = "interactive " + post = " and parameters " + str(tuple(self.params.keys())) + return pre + s + post + + +def _detect_poles_numerical_helper(x, y, eps=0.01, expr=None, symb=None, symbolic=False): + """Compute the steepness of each segment. If it's greater than a + threshold, set the right-point y-value non NaN and record the + corresponding x-location for further processing. + + Returns + ======= + x : np.ndarray + Unchanged x-data. + yy : np.ndarray + Modified y-data with NaN values. + """ + np = import_module('numpy') + + yy = y.copy() + threshold = np.pi / 2 - eps + for i in range(len(x) - 1): + dx = x[i + 1] - x[i] + dy = abs(y[i + 1] - y[i]) + angle = np.arctan(dy / dx) + if abs(angle) >= threshold: + yy[i + 1] = np.nan + + return x, yy + +def _detect_poles_symbolic_helper(expr, symb, start, end): + """Attempts to compute symbolic discontinuities. + + Returns + ======= + pole : list + List of symbolic poles, possibly empty. + """ + poles = [] + interval = Interval(nsimplify(start), nsimplify(end)) + res = continuous_domain(expr, symb, interval) + res = res.simplify() + if res == interval: + pass + elif (isinstance(res, Union) and + all(isinstance(t, Interval) for t in res.args)): + poles = [] + for s in res.args: + if s.left_open: + poles.append(s.left) + if s.right_open: + poles.append(s.right) + poles = list(set(poles)) + else: + raise ValueError( + f"Could not parse the following object: {res} .\n" + "Please, submit this as a bug. Consider also to set " + "`detect_poles=True`." + ) + return poles + + +### 2D lines +class Line2DBaseSeries(BaseSeries): + """A base class for 2D lines. + + - adding the label, steps and only_integers options + - making is_2Dline true + - defining get_segments and get_color_array + """ + + is_2Dline = True + _dim = 2 + _N = 1000 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.steps = kwargs.get("steps", False) + self.is_point = kwargs.get("is_point", kwargs.get("point", False)) + self.is_filled = kwargs.get("is_filled", kwargs.get("fill", True)) + self.adaptive = kwargs.get("adaptive", False) + self.depth = kwargs.get('depth', 12) + self.use_cm = kwargs.get("use_cm", False) + self.color_func = kwargs.get("color_func", None) + self.line_color = kwargs.get("line_color", None) + self.detect_poles = kwargs.get("detect_poles", False) + self.eps = kwargs.get("eps", 0.01) + self.is_polar = kwargs.get("is_polar", kwargs.get("polar", False)) + self.unwrap = kwargs.get("unwrap", False) + # when detect_poles="symbolic", stores the location of poles so that + # they can be appropriately rendered + self.poles_locations = [] + exclude = kwargs.get("exclude", []) + if isinstance(exclude, Set): + exclude = list(extract_solution(exclude, n=100)) + if not hasattr(exclude, "__iter__"): + exclude = [exclude] + exclude = [float(e) for e in exclude] + self.exclude = sorted(exclude) + + def get_data(self): + """Return coordinates for plotting the line. + + Returns + ======= + + x: np.ndarray + x-coordinates + + y: np.ndarray + y-coordinates + + z: np.ndarray (optional) + z-coordinates in case of Parametric3DLineSeries, + Parametric3DLineInteractiveSeries + + param : np.ndarray (optional) + The parameter in case of Parametric2DLineSeries, + Parametric3DLineSeries or AbsArgLineSeries (and their + corresponding interactive series). + """ + np = import_module('numpy') + points = self._get_data_helper() + + if (isinstance(self, LineOver1DRangeSeries) and + (self.detect_poles == "symbolic")): + poles = _detect_poles_symbolic_helper( + self.expr.subs(self.params), *self.ranges[0]) + poles = np.array([float(t) for t in poles]) + t = lambda x, transform: x if transform is None else transform(x) + self.poles_locations = t(np.array(poles), self._tx) + + # postprocessing + points = self._apply_transform(*points) + + if self.is_2Dline and self.detect_poles: + if len(points) == 2: + x, y = points + x, y = _detect_poles_numerical_helper( + x, y, self.eps) + points = (x, y) + else: + x, y, p = points + x, y = _detect_poles_numerical_helper(x, y, self.eps) + points = (x, y, p) + + if self.unwrap: + kw = {} + if self.unwrap is not True: + kw = self.unwrap + if self.is_2Dline: + if len(points) == 2: + x, y = points + y = np.unwrap(y, **kw) + points = (x, y) + else: + x, y, p = points + y = np.unwrap(y, **kw) + points = (x, y, p) + + if self.steps is True: + if self.is_2Dline: + x, y = points[0], points[1] + x = np.array((x, x)).T.flatten()[1:] + y = np.array((y, y)).T.flatten()[:-1] + if self.is_parametric: + points = (x, y, points[2]) + else: + points = (x, y) + elif self.is_3Dline: + x = np.repeat(points[0], 3)[2:] + y = np.repeat(points[1], 3)[:-2] + z = np.repeat(points[2], 3)[1:-1] + if len(points) > 3: + points = (x, y, z, points[3]) + else: + points = (x, y, z) + + if len(self.exclude) > 0: + points = self._insert_exclusions(points) + return points + + def get_segments(self): + sympy_deprecation_warning( + """ + The Line2DBaseSeries.get_segments() method is deprecated. + + Instead, use the MatplotlibBackend.get_segments() method, or use + The get_points() or get_data() methods. + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-get-segments") + + np = import_module('numpy') + points = type(self).get_data(self) + points = np.ma.array(points).T.reshape(-1, 1, self._dim) + return np.ma.concatenate([points[:-1], points[1:]], axis=1) + + def _insert_exclusions(self, points): + """Add NaN to each of the exclusion point. Practically, this adds a + NaN to the exclusion point, plus two other nearby points evaluated with + the numerical functions associated to this data series. + These nearby points are important when the number of discretization + points is low, or the scale is logarithm. + + NOTE: it would be easier to just add exclusion points to the + discretized domain before evaluation, then after evaluation add NaN + to the exclusion points. But that's only work with adaptive=False. + The following approach work even with adaptive=True. + """ + np = import_module("numpy") + points = list(points) + n = len(points) + # index of the x-coordinate (for 2d plots) or parameter (for 2d/3d + # parametric plots) + k = n - 1 + if n == 2: + k = 0 + # indices of the other coordinates + j_indeces = sorted(set(range(n)).difference([k])) + # TODO: for now, I assume that numpy functions are going to succeed + funcs = [f[0] for f in self._functions] + + for e in self.exclude: + res = points[k] - e >= 0 + # if res contains both True and False, ie, if e is found + if any(res) and any(~res): + idx = np.nanargmax(res) + # select the previous point with respect to e + idx -= 1 + # TODO: what if points[k][idx]==e or points[k][idx+1]==e? + + if idx > 0 and idx < len(points[k]) - 1: + delta_prev = abs(e - points[k][idx]) + delta_post = abs(e - points[k][idx + 1]) + delta = min(delta_prev, delta_post) / 100 + prev = e - delta + post = e + delta + + # add points to the x-coord or the parameter + points[k] = np.concatenate( + (points[k][:idx], [prev, e, post], points[k][idx+1:])) + + # add points to the other coordinates + c = 0 + for j in j_indeces: + values = funcs[c](np.array([prev, post])) + c += 1 + points[j] = np.concatenate( + (points[j][:idx], [values[0], np.nan, values[1]], points[j][idx+1:])) + return points + + @property + def var(self): + return None if not self.ranges else self.ranges[0][0] + + @property + def start(self): + if not self.ranges: + return None + try: + return self._cast(self.ranges[0][1]) + except TypeError: + return self.ranges[0][1] + + @property + def end(self): + if not self.ranges: + return None + try: + return self._cast(self.ranges[0][2]) + except TypeError: + return self.ranges[0][2] + + @property + def xscale(self): + return self._scales[0] + + @xscale.setter + def xscale(self, v): + self.scales = v + + def get_color_array(self): + np = import_module('numpy') + c = self.line_color + if hasattr(c, '__call__'): + f = np.vectorize(c) + nargs = arity(c) + if nargs == 1 and self.is_parametric: + x = self.get_parameter_points() + return f(centers_of_segments(x)) + else: + variables = list(map(centers_of_segments, self.get_points())) + if nargs == 1: + return f(variables[0]) + elif nargs == 2: + return f(*variables[:2]) + else: # only if the line is 3D (otherwise raises an error) + return f(*variables) + else: + return c*np.ones(self.nb_of_points) + + +class List2DSeries(Line2DBaseSeries): + """Representation for a line consisting of list of points.""" + + def __init__(self, list_x, list_y, label="", **kwargs): + super().__init__(**kwargs) + np = import_module('numpy') + if len(list_x) != len(list_y): + raise ValueError( + "The two lists of coordinates must have the same " + "number of elements.\n" + "Received: len(list_x) = {} ".format(len(list_x)) + + "and len(list_y) = {}".format(len(list_y)) + ) + self._block_lambda_functions(list_x, list_y) + check = lambda l: [isinstance(t, Expr) and (not t.is_number) for t in l] + if any(check(list_x) + check(list_y)) or self.params: + if not self.params: + raise ValueError("Some or all elements of the provided lists " + "are symbolic expressions, but the ``params`` dictionary " + "was not provided: those elements can't be evaluated.") + self.list_x = Tuple(*list_x) + self.list_y = Tuple(*list_y) + else: + self.list_x = np.array(list_x, dtype=np.float64) + self.list_y = np.array(list_y, dtype=np.float64) + + self._expr = (self.list_x, self.list_y) + if not any(isinstance(t, np.ndarray) for t in [self.list_x, self.list_y]): + self._check_fs() + self.is_polar = kwargs.get("is_polar", kwargs.get("polar", False)) + self.label = label + self.rendering_kw = kwargs.get("rendering_kw", {}) + if self.use_cm and self.color_func: + self.is_parametric = True + if isinstance(self.color_func, Expr): + raise TypeError( + "%s don't support symbolic " % self.__class__.__name__ + + "expression for `color_func`.") + + def __str__(self): + return "2D list plot" + + def _get_data_helper(self): + """Returns coordinates that needs to be postprocessed.""" + lx, ly = self.list_x, self.list_y + + if not self.is_interactive: + return self._eval_color_func_and_return(lx, ly) + + np = import_module('numpy') + lx = np.array([t.evalf(subs=self.params) for t in lx], dtype=float) + ly = np.array([t.evalf(subs=self.params) for t in ly], dtype=float) + return self._eval_color_func_and_return(lx, ly) + + def _eval_color_func_and_return(self, *data): + if self.use_cm and callable(self.color_func): + return [*data, self.eval_color_func(*data)] + return data + + +class LineOver1DRangeSeries(Line2DBaseSeries): + """Representation for a line consisting of a SymPy expression over a range.""" + + def __init__(self, expr, var_start_end, label="", **kwargs): + super().__init__(**kwargs) + self.expr = expr if callable(expr) else sympify(expr) + self._label = str(self.expr) if label is None else label + self._latex_label = latex(self.expr) if label is None else label + self.ranges = [var_start_end] + self._cast = complex + # for complex-related data series, this determines what data to return + # on the y-axis + self._return = kwargs.get("return", None) + self._post_init() + + if not self._interactive_ranges: + # NOTE: the following check is only possible when the minimum and + # maximum values of a plotting range are numeric + start, end = [complex(t) for t in self.ranges[0][1:]] + if im(start) != im(end): + raise ValueError( + "%s requires the imaginary " % self.__class__.__name__ + + "part of the start and end values of the range " + "to be the same.") + + if self.adaptive and self._return: + warnings.warn("The adaptive algorithm is unable to deal with " + "complex numbers. Automatically switching to uniform meshing.") + self.adaptive = False + + @property + def nb_of_points(self): + return self.n[0] + + @nb_of_points.setter + def nb_of_points(self, v): + self.n = v + + def __str__(self): + def f(t): + if isinstance(t, complex): + if t.imag != 0: + return t + return t.real + return t + pre = "interactive " if self.is_interactive else "" + post = "" + if self.is_interactive: + post = " and parameters " + str(tuple(self.params.keys())) + wrapper = _get_wrapper_for_expr(self._return) + return pre + "cartesian line: %s for %s over %s" % ( + wrapper % self.expr, + str(self.var), + str((f(self.start), f(self.end))), + ) + post + + def get_points(self): + """Return lists of coordinates for plotting. Depending on the + ``adaptive`` option, this function will either use an adaptive algorithm + or it will uniformly sample the expression over the provided range. + + This function is available for back-compatibility purposes. Consider + using ``get_data()`` instead. + + Returns + ======= + x : list + List of x-coordinates + + y : list + List of y-coordinates + """ + return self._get_data_helper() + + def _adaptive_sampling(self): + try: + if callable(self.expr): + f = self.expr + else: + f = lambdify([self.var], self.expr, self.modules) + x, y = self._adaptive_sampling_helper(f) + except Exception as err: + warnings.warn( + "The evaluation with %s failed.\n" % ( + "NumPy/SciPy" if not self.modules else self.modules) + + "{}: {}\n".format(type(err).__name__, err) + + "Trying to evaluate the expression with Sympy, but it might " + "be a slow operation." + ) + f = lambdify([self.var], self.expr, "sympy") + x, y = self._adaptive_sampling_helper(f) + return x, y + + def _adaptive_sampling_helper(self, f): + """The adaptive sampling is done by recursively checking if three + points are almost collinear. If they are not collinear, then more + points are added between those points. + + References + ========== + + .. [1] Adaptive polygonal approximation of parametric curves, + Luiz Henrique de Figueiredo. + """ + np = import_module('numpy') + + x_coords = [] + y_coords = [] + def sample(p, q, depth): + """ Samples recursively if three points are almost collinear. + For depth < 6, points are added irrespective of whether they + satisfy the collinearity condition or not. The maximum depth + allowed is 12. + """ + # Randomly sample to avoid aliasing. + random = 0.45 + np.random.rand() * 0.1 + if self.xscale == 'log': + xnew = 10**(np.log10(p[0]) + random * (np.log10(q[0]) - + np.log10(p[0]))) + else: + xnew = p[0] + random * (q[0] - p[0]) + ynew = _adaptive_eval(f, xnew) + new_point = np.array([xnew, ynew]) + + # Maximum depth + if depth > self.depth: + x_coords.append(q[0]) + y_coords.append(q[1]) + + # Sample to depth of 6 (whether the line is flat or not) + # without using linspace (to avoid aliasing). + elif depth < 6: + sample(p, new_point, depth + 1) + sample(new_point, q, depth + 1) + + # Sample ten points if complex values are encountered + # at both ends. If there is a real value in between, then + # sample those points further. + elif p[1] is None and q[1] is None: + if self.xscale == 'log': + xarray = np.logspace(p[0], q[0], 10) + else: + xarray = np.linspace(p[0], q[0], 10) + yarray = list(map(f, xarray)) + if not all(y is None for y in yarray): + for i in range(len(yarray) - 1): + if not (yarray[i] is None and yarray[i + 1] is None): + sample([xarray[i], yarray[i]], + [xarray[i + 1], yarray[i + 1]], depth + 1) + + # Sample further if one of the end points in None (i.e. a + # complex value) or the three points are not almost collinear. + elif (p[1] is None or q[1] is None or new_point[1] is None + or not flat(p, new_point, q)): + sample(p, new_point, depth + 1) + sample(new_point, q, depth + 1) + else: + x_coords.append(q[0]) + y_coords.append(q[1]) + + f_start = _adaptive_eval(f, self.start.real) + f_end = _adaptive_eval(f, self.end.real) + x_coords.append(self.start.real) + y_coords.append(f_start) + sample(np.array([self.start.real, f_start]), + np.array([self.end.real, f_end]), 0) + + return (x_coords, y_coords) + + def _uniform_sampling(self): + np = import_module('numpy') + + x, result = self._evaluate() + _re, _im = np.real(result), np.imag(result) + _re = self._correct_shape(_re, x) + _im = self._correct_shape(_im, x) + return x, _re, _im + + def _get_data_helper(self): + """Returns coordinates that needs to be postprocessed. + """ + np = import_module('numpy') + if self.adaptive and (not self.only_integers): + x, y = self._adaptive_sampling() + return [np.array(t) for t in [x, y]] + + x, _re, _im = self._uniform_sampling() + + if self._return is None: + # The evaluation could produce complex numbers. Set real elements + # to NaN where there are non-zero imaginary elements + _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan + elif self._return == "real": + pass + elif self._return == "imag": + _re = _im + elif self._return == "abs": + _re = np.sqrt(_re**2 + _im**2) + elif self._return == "arg": + _re = np.arctan2(_im, _re) + else: + raise ValueError("`_return` not recognized. " + "Received: %s" % self._return) + + return x, _re + + +class ParametricLineBaseSeries(Line2DBaseSeries): + is_parametric = True + + def _set_parametric_line_label(self, label): + """Logic to set the correct label to be shown on the plot. + If `use_cm=True` there will be a colorbar, so we show the parameter. + If `use_cm=False`, there might be a legend, so we show the expressions. + + Parameters + ========== + label : str + label passed in by the pre-processor or the user + """ + self._label = str(self.var) if label is None else label + self._latex_label = latex(self.var) if label is None else label + if (self.use_cm is False) and (self._label == str(self.var)): + self._label = str(self.expr) + self._latex_label = latex(self.expr) + # if the expressions is a lambda function and use_cm=False and no label + # has been provided, then its better to do the following in order to + # avoid surprises on the backend + if any(callable(e) for e in self.expr) and (not self.use_cm): + if self._label == str(self.expr): + self._label = "" + + def get_label(self, use_latex=False, wrapper="$%s$"): + # parametric lines returns the representation of the parameter to be + # shown on the colorbar if `use_cm=True`, otherwise it returns the + # representation of the expression to be placed on the legend. + if self.use_cm: + if str(self.var) == self._label: + if use_latex: + return self._get_wrapped_label(latex(self.var), wrapper) + return str(self.var) + # here the user has provided a custom label + return self._label + if use_latex: + if self._label != str(self.expr): + return self._latex_label + return self._get_wrapped_label(self._latex_label, wrapper) + return self._label + + def _get_data_helper(self): + """Returns coordinates that needs to be postprocessed. + Depending on the `adaptive` option, this function will either use an + adaptive algorithm or it will uniformly sample the expression over the + provided range. + """ + if self.adaptive: + np = import_module("numpy") + coords = self._adaptive_sampling() + coords = [np.array(t) for t in coords] + else: + coords = self._uniform_sampling() + + if self.is_2Dline and self.is_polar: + # when plot_polar is executed with polar_axis=True + np = import_module('numpy') + x, y, _ = coords + r = np.sqrt(x**2 + y**2) + t = np.arctan2(y, x) + coords = [t, r, coords[-1]] + + if callable(self.color_func): + coords = list(coords) + coords[-1] = self.eval_color_func(*coords) + + return coords + + def _uniform_sampling(self): + """Returns coordinates that needs to be postprocessed.""" + np = import_module('numpy') + + results = self._evaluate() + for i, r in enumerate(results): + _re, _im = np.real(r), np.imag(r) + _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan + results[i] = _re + + return [*results[1:], results[0]] + + def get_parameter_points(self): + return self.get_data()[-1] + + def get_points(self): + """ Return lists of coordinates for plotting. Depending on the + ``adaptive`` option, this function will either use an adaptive algorithm + or it will uniformly sample the expression over the provided range. + + This function is available for back-compatibility purposes. Consider + using ``get_data()`` instead. + + Returns + ======= + x : list + List of x-coordinates + y : list + List of y-coordinates + z : list + List of z-coordinates, only for 3D parametric line plot. + """ + return self._get_data_helper()[:-1] + + @property + def nb_of_points(self): + return self.n[0] + + @nb_of_points.setter + def nb_of_points(self, v): + self.n = v + + +class Parametric2DLineSeries(ParametricLineBaseSeries): + """Representation for a line consisting of two parametric SymPy expressions + over a range.""" + + is_2Dline = True + + def __init__(self, expr_x, expr_y, var_start_end, label="", **kwargs): + super().__init__(**kwargs) + self.expr_x = expr_x if callable(expr_x) else sympify(expr_x) + self.expr_y = expr_y if callable(expr_y) else sympify(expr_y) + self.expr = (self.expr_x, self.expr_y) + self.ranges = [var_start_end] + self._cast = float + self.use_cm = kwargs.get("use_cm", True) + self._set_parametric_line_label(label) + self._post_init() + + def __str__(self): + return self._str_helper( + "parametric cartesian line: (%s, %s) for %s over %s" % ( + str(self.expr_x), + str(self.expr_y), + str(self.var), + str((self.start, self.end)) + )) + + def _adaptive_sampling(self): + try: + if callable(self.expr_x) and callable(self.expr_y): + f_x = self.expr_x + f_y = self.expr_y + else: + f_x = lambdify([self.var], self.expr_x) + f_y = lambdify([self.var], self.expr_y) + x, y, p = self._adaptive_sampling_helper(f_x, f_y) + except Exception as err: + warnings.warn( + "The evaluation with %s failed.\n" % ( + "NumPy/SciPy" if not self.modules else self.modules) + + "{}: {}\n".format(type(err).__name__, err) + + "Trying to evaluate the expression with Sympy, but it might " + "be a slow operation." + ) + f_x = lambdify([self.var], self.expr_x, "sympy") + f_y = lambdify([self.var], self.expr_y, "sympy") + x, y, p = self._adaptive_sampling_helper(f_x, f_y) + return x, y, p + + def _adaptive_sampling_helper(self, f_x, f_y): + """The adaptive sampling is done by recursively checking if three + points are almost collinear. If they are not collinear, then more + points are added between those points. + + References + ========== + + .. [1] Adaptive polygonal approximation of parametric curves, + Luiz Henrique de Figueiredo. + """ + x_coords = [] + y_coords = [] + param = [] + + def sample(param_p, param_q, p, q, depth): + """ Samples recursively if three points are almost collinear. + For depth < 6, points are added irrespective of whether they + satisfy the collinearity condition or not. The maximum depth + allowed is 12. + """ + # Randomly sample to avoid aliasing. + np = import_module('numpy') + random = 0.45 + np.random.rand() * 0.1 + param_new = param_p + random * (param_q - param_p) + xnew = _adaptive_eval(f_x, param_new) + ynew = _adaptive_eval(f_y, param_new) + new_point = np.array([xnew, ynew]) + + # Maximum depth + if depth > self.depth: + x_coords.append(q[0]) + y_coords.append(q[1]) + param.append(param_p) + + # Sample irrespective of whether the line is flat till the + # depth of 6. We are not using linspace to avoid aliasing. + elif depth < 6: + sample(param_p, param_new, p, new_point, depth + 1) + sample(param_new, param_q, new_point, q, depth + 1) + + # Sample ten points if complex values are encountered + # at both ends. If there is a real value in between, then + # sample those points further. + elif ((p[0] is None and q[1] is None) or + (p[1] is None and q[1] is None)): + param_array = np.linspace(param_p, param_q, 10) + x_array = [_adaptive_eval(f_x, t) for t in param_array] + y_array = [_adaptive_eval(f_y, t) for t in param_array] + if not all(x is None and y is None + for x, y in zip(x_array, y_array)): + for i in range(len(y_array) - 1): + if ((x_array[i] is not None and y_array[i] is not None) or + (x_array[i + 1] is not None and y_array[i + 1] is not None)): + point_a = [x_array[i], y_array[i]] + point_b = [x_array[i + 1], y_array[i + 1]] + sample(param_array[i], param_array[i], point_a, + point_b, depth + 1) + + # Sample further if one of the end points in None (i.e. a complex + # value) or the three points are not almost collinear. + elif (p[0] is None or p[1] is None + or q[1] is None or q[0] is None + or not flat(p, new_point, q)): + sample(param_p, param_new, p, new_point, depth + 1) + sample(param_new, param_q, new_point, q, depth + 1) + else: + x_coords.append(q[0]) + y_coords.append(q[1]) + param.append(param_p) + + f_start_x = _adaptive_eval(f_x, self.start) + f_start_y = _adaptive_eval(f_y, self.start) + start = [f_start_x, f_start_y] + f_end_x = _adaptive_eval(f_x, self.end) + f_end_y = _adaptive_eval(f_y, self.end) + end = [f_end_x, f_end_y] + x_coords.append(f_start_x) + y_coords.append(f_start_y) + param.append(self.start) + sample(self.start, self.end, start, end, 0) + + return x_coords, y_coords, param + + +### 3D lines +class Line3DBaseSeries(Line2DBaseSeries): + """A base class for 3D lines. + + Most of the stuff is derived from Line2DBaseSeries.""" + + is_2Dline = False + is_3Dline = True + _dim = 3 + + def __init__(self): + super().__init__() + + +class Parametric3DLineSeries(ParametricLineBaseSeries): + """Representation for a 3D line consisting of three parametric SymPy + expressions and a range.""" + + is_2Dline = False + is_3Dline = True + + def __init__(self, expr_x, expr_y, expr_z, var_start_end, label="", **kwargs): + super().__init__(**kwargs) + self.expr_x = expr_x if callable(expr_x) else sympify(expr_x) + self.expr_y = expr_y if callable(expr_y) else sympify(expr_y) + self.expr_z = expr_z if callable(expr_z) else sympify(expr_z) + self.expr = (self.expr_x, self.expr_y, self.expr_z) + self.ranges = [var_start_end] + self._cast = float + self.adaptive = False + self.use_cm = kwargs.get("use_cm", True) + self._set_parametric_line_label(label) + self._post_init() + # TODO: remove this + self._xlim = None + self._ylim = None + self._zlim = None + + def __str__(self): + return self._str_helper( + "3D parametric cartesian line: (%s, %s, %s) for %s over %s" % ( + str(self.expr_x), + str(self.expr_y), + str(self.expr_z), + str(self.var), + str((self.start, self.end)) + )) + + def get_data(self): + # TODO: remove this + np = import_module("numpy") + x, y, z, p = super().get_data() + self._xlim = (np.amin(x), np.amax(x)) + self._ylim = (np.amin(y), np.amax(y)) + self._zlim = (np.amin(z), np.amax(z)) + return x, y, z, p + + +### Surfaces +class SurfaceBaseSeries(BaseSeries): + """A base class for 3D surfaces.""" + + is_3Dsurface = True + + def __init__(self, *args, **kwargs): + super().__init__(**kwargs) + self.use_cm = kwargs.get("use_cm", False) + # NOTE: why should SurfaceOver2DRangeSeries support is polar? + # After all, the same result can be achieve with + # ParametricSurfaceSeries. For example: + # sin(r) for (r, 0, 2 * pi) and (theta, 0, pi/2) can be parameterized + # as (r * cos(theta), r * sin(theta), sin(t)) for (r, 0, 2 * pi) and + # (theta, 0, pi/2). + # Because it is faster to evaluate (important for interactive plots). + self.is_polar = kwargs.get("is_polar", kwargs.get("polar", False)) + self.surface_color = kwargs.get("surface_color", None) + self.color_func = kwargs.get("color_func", lambda x, y, z: z) + if callable(self.surface_color): + self.color_func = self.surface_color + self.surface_color = None + + def _set_surface_label(self, label): + exprs = self.expr + self._label = str(exprs) if label is None else label + self._latex_label = latex(exprs) if label is None else label + # if the expressions is a lambda function and no label + # has been provided, then its better to do the following to avoid + # surprises on the backend + is_lambda = (callable(exprs) if not hasattr(exprs, "__iter__") + else any(callable(e) for e in exprs)) + if is_lambda and (self._label == str(exprs)): + self._label = "" + self._latex_label = "" + + def get_color_array(self): + np = import_module('numpy') + c = self.surface_color + if isinstance(c, Callable): + f = np.vectorize(c) + nargs = arity(c) + if self.is_parametric: + variables = list(map(centers_of_faces, self.get_parameter_meshes())) + if nargs == 1: + return f(variables[0]) + elif nargs == 2: + return f(*variables) + variables = list(map(centers_of_faces, self.get_meshes())) + if nargs == 1: + return f(variables[0]) + elif nargs == 2: + return f(*variables[:2]) + else: + return f(*variables) + else: + if isinstance(self, SurfaceOver2DRangeSeries): + return c*np.ones(min(self.nb_of_points_x, self.nb_of_points_y)) + else: + return c*np.ones(min(self.nb_of_points_u, self.nb_of_points_v)) + + +class SurfaceOver2DRangeSeries(SurfaceBaseSeries): + """Representation for a 3D surface consisting of a SymPy expression and 2D + range.""" + + def __init__(self, expr, var_start_end_x, var_start_end_y, label="", **kwargs): + super().__init__(**kwargs) + self.expr = expr if callable(expr) else sympify(expr) + self.ranges = [var_start_end_x, var_start_end_y] + self._set_surface_label(label) + self._post_init() + # TODO: remove this + self._xlim = (self.start_x, self.end_x) + self._ylim = (self.start_y, self.end_y) + + @property + def var_x(self): + return self.ranges[0][0] + + @property + def var_y(self): + return self.ranges[1][0] + + @property + def start_x(self): + try: + return float(self.ranges[0][1]) + except TypeError: + return self.ranges[0][1] + + @property + def end_x(self): + try: + return float(self.ranges[0][2]) + except TypeError: + return self.ranges[0][2] + + @property + def start_y(self): + try: + return float(self.ranges[1][1]) + except TypeError: + return self.ranges[1][1] + + @property + def end_y(self): + try: + return float(self.ranges[1][2]) + except TypeError: + return self.ranges[1][2] + + @property + def nb_of_points_x(self): + return self.n[0] + + @nb_of_points_x.setter + def nb_of_points_x(self, v): + n = self.n + self.n = [v, n[1:]] + + @property + def nb_of_points_y(self): + return self.n[1] + + @nb_of_points_y.setter + def nb_of_points_y(self, v): + n = self.n + self.n = [n[0], v, n[2]] + + def __str__(self): + series_type = "cartesian surface" if self.is_3Dsurface else "contour" + return self._str_helper( + series_type + ": %s for" " %s over %s and %s over %s" % ( + str(self.expr), + str(self.var_x), str((self.start_x, self.end_x)), + str(self.var_y), str((self.start_y, self.end_y)), + )) + + def get_meshes(self): + """Return the x,y,z coordinates for plotting the surface. + This function is available for back-compatibility purposes. Consider + using ``get_data()`` instead. + """ + return self.get_data() + + def get_data(self): + """Return arrays of coordinates for plotting. + + Returns + ======= + mesh_x : np.ndarray + Discretized x-domain. + mesh_y : np.ndarray + Discretized y-domain. + mesh_z : np.ndarray + Results of the evaluation. + """ + np = import_module('numpy') + + results = self._evaluate() + # mask out complex values + for i, r in enumerate(results): + _re, _im = np.real(r), np.imag(r) + _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan + results[i] = _re + + x, y, z = results + if self.is_polar and self.is_3Dsurface: + r = x.copy() + x = r * np.cos(y) + y = r * np.sin(y) + + # TODO: remove this + self._zlim = (np.amin(z), np.amax(z)) + + return self._apply_transform(x, y, z) + + +class ParametricSurfaceSeries(SurfaceBaseSeries): + """Representation for a 3D surface consisting of three parametric SymPy + expressions and a range.""" + + is_parametric = True + + def __init__(self, expr_x, expr_y, expr_z, + var_start_end_u, var_start_end_v, label="", **kwargs): + super().__init__(**kwargs) + self.expr_x = expr_x if callable(expr_x) else sympify(expr_x) + self.expr_y = expr_y if callable(expr_y) else sympify(expr_y) + self.expr_z = expr_z if callable(expr_z) else sympify(expr_z) + self.expr = (self.expr_x, self.expr_y, self.expr_z) + self.ranges = [var_start_end_u, var_start_end_v] + self.color_func = kwargs.get("color_func", lambda x, y, z, u, v: z) + self._set_surface_label(label) + self._post_init() + + @property + def var_u(self): + return self.ranges[0][0] + + @property + def var_v(self): + return self.ranges[1][0] + + @property + def start_u(self): + try: + return float(self.ranges[0][1]) + except TypeError: + return self.ranges[0][1] + + @property + def end_u(self): + try: + return float(self.ranges[0][2]) + except TypeError: + return self.ranges[0][2] + + @property + def start_v(self): + try: + return float(self.ranges[1][1]) + except TypeError: + return self.ranges[1][1] + + @property + def end_v(self): + try: + return float(self.ranges[1][2]) + except TypeError: + return self.ranges[1][2] + + @property + def nb_of_points_u(self): + return self.n[0] + + @nb_of_points_u.setter + def nb_of_points_u(self, v): + n = self.n + self.n = [v, n[1:]] + + @property + def nb_of_points_v(self): + return self.n[1] + + @nb_of_points_v.setter + def nb_of_points_v(self, v): + n = self.n + self.n = [n[0], v, n[2]] + + def __str__(self): + return self._str_helper( + "parametric cartesian surface: (%s, %s, %s) for" + " %s over %s and %s over %s" % ( + str(self.expr_x), str(self.expr_y), str(self.expr_z), + str(self.var_u), str((self.start_u, self.end_u)), + str(self.var_v), str((self.start_v, self.end_v)), + )) + + def get_parameter_meshes(self): + return self.get_data()[3:] + + def get_meshes(self): + """Return the x,y,z coordinates for plotting the surface. + This function is available for back-compatibility purposes. Consider + using ``get_data()`` instead. + """ + return self.get_data()[:3] + + def get_data(self): + """Return arrays of coordinates for plotting. + + Returns + ======= + x : np.ndarray [n2 x n1] + x-coordinates. + y : np.ndarray [n2 x n1] + y-coordinates. + z : np.ndarray [n2 x n1] + z-coordinates. + mesh_u : np.ndarray [n2 x n1] + Discretized u range. + mesh_v : np.ndarray [n2 x n1] + Discretized v range. + """ + np = import_module('numpy') + + results = self._evaluate() + # mask out complex values + for i, r in enumerate(results): + _re, _im = np.real(r), np.imag(r) + _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan + results[i] = _re + + # TODO: remove this + x, y, z = results[2:] + self._xlim = (np.amin(x), np.amax(x)) + self._ylim = (np.amin(y), np.amax(y)) + self._zlim = (np.amin(z), np.amax(z)) + + return self._apply_transform(*results[2:], *results[:2]) + + +### Contours +class ContourSeries(SurfaceOver2DRangeSeries): + """Representation for a contour plot.""" + + is_3Dsurface = False + is_contour = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_filled = kwargs.get("is_filled", kwargs.get("fill", True)) + self.show_clabels = kwargs.get("clabels", True) + + # NOTE: contour plots are used by plot_contour, plot_vector and + # plot_complex_vector. By implementing contour_kw we are able to + # quickly target the contour plot. + self.rendering_kw = kwargs.get("contour_kw", + kwargs.get("rendering_kw", {})) + + +class GenericDataSeries(BaseSeries): + """Represents generic numerical data. + + Notes + ===== + This class serves the purpose of back-compatibility with the "markers, + annotations, fill, rectangles" keyword arguments that represent + user-provided numerical data. In particular, it solves the problem of + combining together two or more plot-objects with the ``extend`` or + ``append`` methods: user-provided numerical data is also taken into + consideration because it is stored in this series class. + + Also note that the current implementation is far from optimal, as each + keyword argument is stored into an attribute in the ``Plot`` class, which + requires a hard-coded if-statement in the ``MatplotlibBackend`` class. + The implementation suggests that it is ok to add attributes and + if-statements to provide more and more functionalities for user-provided + numerical data (e.g. adding horizontal lines, or vertical lines, or bar + plots, etc). However, in doing so one would reinvent the wheel: plotting + libraries (like Matplotlib) already implements the necessary API. + + Instead of adding more keyword arguments and attributes, users interested + in adding custom numerical data to a plot should retrieve the figure + created by this plotting module. For example, this code: + + .. plot:: + :context: close-figs + :include-source: True + + from sympy import Symbol, plot, cos + x = Symbol("x") + p = plot(cos(x), markers=[{"args": [[0, 1, 2], [0, 1, -1], "*"]}]) + + Becomes: + + .. plot:: + :context: close-figs + :include-source: True + + p = plot(cos(x), backend="matplotlib") + fig, ax = p._backend.fig, p._backend.ax + ax.plot([0, 1, 2], [0, 1, -1], "*") + fig + + Which is far better in terms of readability. Also, it gives access to the + full plotting library capabilities, without the need to reinvent the wheel. + """ + is_generic = True + + def __init__(self, tp, *args, **kwargs): + self.type = tp + self.args = args + self.rendering_kw = kwargs + + def get_data(self): + return self.args + + +class ImplicitSeries(BaseSeries): + """Representation for 2D Implicit plot.""" + + is_implicit = True + use_cm = False + _N = 100 + + def __init__(self, expr, var_start_end_x, var_start_end_y, label="", **kwargs): + super().__init__(**kwargs) + self.adaptive = kwargs.get("adaptive", False) + self.expr = expr + self._label = str(expr) if label is None else label + self._latex_label = latex(expr) if label is None else label + self.ranges = [var_start_end_x, var_start_end_y] + self.var_x, self.start_x, self.end_x = self.ranges[0] + self.var_y, self.start_y, self.end_y = self.ranges[1] + self._color = kwargs.get("color", kwargs.get("line_color", None)) + + if self.is_interactive and self.adaptive: + raise NotImplementedError("Interactive plot with `adaptive=True` " + "is not supported.") + + # Check whether the depth is greater than 4 or less than 0. + depth = kwargs.get("depth", 0) + if depth > 4: + depth = 4 + elif depth < 0: + depth = 0 + self.depth = 4 + depth + self._post_init() + + @property + def expr(self): + if self.adaptive: + return self._adaptive_expr + return self._non_adaptive_expr + + @expr.setter + def expr(self, expr): + self._block_lambda_functions(expr) + # these are needed for adaptive evaluation + expr, has_equality = self._has_equality(sympify(expr)) + self._adaptive_expr = expr + self.has_equality = has_equality + self._label = str(expr) + self._latex_label = latex(expr) + + if isinstance(expr, (BooleanFunction, Ne)) and (not self.adaptive): + self.adaptive = True + msg = "contains Boolean functions. " + if isinstance(expr, Ne): + msg = "is an unequality. " + warnings.warn( + "The provided expression " + msg + + "In order to plot the expression, the algorithm " + + "automatically switched to an adaptive sampling." + ) + + if isinstance(expr, BooleanFunction): + self._non_adaptive_expr = None + self._is_equality = False + else: + # these are needed for uniform meshing evaluation + expr, is_equality = self._preprocess_meshgrid_expression(expr, self.adaptive) + self._non_adaptive_expr = expr + self._is_equality = is_equality + + @property + def line_color(self): + return self._color + + @line_color.setter + def line_color(self, v): + self._color = v + + color = line_color + + def _has_equality(self, expr): + # Represents whether the expression contains an Equality, GreaterThan + # or LessThan + has_equality = False + + def arg_expand(bool_expr): + """Recursively expands the arguments of an Boolean Function""" + for arg in bool_expr.args: + if isinstance(arg, BooleanFunction): + arg_expand(arg) + elif isinstance(arg, Relational): + arg_list.append(arg) + + arg_list = [] + if isinstance(expr, BooleanFunction): + arg_expand(expr) + # Check whether there is an equality in the expression provided. + if any(isinstance(e, (Equality, GreaterThan, LessThan)) for e in arg_list): + has_equality = True + elif not isinstance(expr, Relational): + expr = Equality(expr, 0) + has_equality = True + elif isinstance(expr, (Equality, GreaterThan, LessThan)): + has_equality = True + + return expr, has_equality + + def __str__(self): + f = lambda t: float(t) if len(t.free_symbols) == 0 else t + + return self._str_helper( + "Implicit expression: %s for %s over %s and %s over %s") % ( + str(self._adaptive_expr), + str(self.var_x), + str((f(self.start_x), f(self.end_x))), + str(self.var_y), + str((f(self.start_y), f(self.end_y))), + ) + + def get_data(self): + """Returns numerical data. + + Returns + ======= + + If the series is evaluated with the `adaptive=True` it returns: + + interval_list : list + List of bounding rectangular intervals to be postprocessed and + eventually used with Matplotlib's ``fill`` command. + dummy : str + A string containing ``"fill"``. + + Otherwise, it returns 2D numpy arrays to be used with Matplotlib's + ``contour`` or ``contourf`` commands: + + x_array : np.ndarray + y_array : np.ndarray + z_array : np.ndarray + plot_type : str + A string specifying which plot command to use, ``"contour"`` + or ``"contourf"``. + """ + if self.adaptive: + data = self._adaptive_eval() + if data is not None: + return data + + return self._get_meshes_grid() + + def _adaptive_eval(self): + """ + References + ========== + + .. [1] Jeffrey Allen Tupper. Reliable Two-Dimensional Graphing Methods for + Mathematical Formulae with Two Free Variables. + + .. [2] Jeffrey Allen Tupper. Graphing Equations with Generalized Interval + Arithmetic. Master's thesis. University of Toronto, 1996 + """ + import sympy.plotting.intervalmath.lib_interval as li + + user_functions = {} + printer = IntervalMathPrinter({ + 'fully_qualified_modules': False, 'inline': True, + 'allow_unknown_functions': True, + 'user_functions': user_functions}) + + keys = [t for t in dir(li) if ("__" not in t) and (t not in ["import_module", "interval"])] + vals = [getattr(li, k) for k in keys] + d = dict(zip(keys, vals)) + func = lambdify((self.var_x, self.var_y), self.expr, modules=[d], printer=printer) + data = None + + try: + data = self._get_raster_interval(func) + except NameError as err: + warnings.warn( + "Adaptive meshing could not be applied to the" + " expression, as some functions are not yet implemented" + " in the interval math module:\n\n" + "NameError: %s\n\n" % err + + "Proceeding with uniform meshing." + ) + self.adaptive = False + except TypeError: + warnings.warn( + "Adaptive meshing could not be applied to the" + " expression. Using uniform meshing.") + self.adaptive = False + + return data + + def _get_raster_interval(self, func): + """Uses interval math to adaptively mesh and obtain the plot""" + np = import_module('numpy') + + k = self.depth + interval_list = [] + sx, sy = [float(t) for t in [self.start_x, self.start_y]] + ex, ey = [float(t) for t in [self.end_x, self.end_y]] + # Create initial 32 divisions + xsample = np.linspace(sx, ex, 33) + ysample = np.linspace(sy, ey, 33) + + # Add a small jitter so that there are no false positives for equality. + # Ex: y==x becomes True for x interval(1, 2) and y interval(1, 2) + # which will draw a rectangle. + jitterx = ( + (np.random.rand(len(xsample)) * 2 - 1) + * (ex - sx) + / 2 ** 20 + ) + jittery = ( + (np.random.rand(len(ysample)) * 2 - 1) + * (ey - sy) + / 2 ** 20 + ) + xsample += jitterx + ysample += jittery + + xinter = [interval(x1, x2) for x1, x2 in zip(xsample[:-1], xsample[1:])] + yinter = [interval(y1, y2) for y1, y2 in zip(ysample[:-1], ysample[1:])] + interval_list = [[x, y] for x in xinter for y in yinter] + plot_list = [] + + # recursive call refinepixels which subdivides the intervals which are + # neither True nor False according to the expression. + def refine_pixels(interval_list): + """Evaluates the intervals and subdivides the interval if the + expression is partially satisfied.""" + temp_interval_list = [] + plot_list = [] + for intervals in interval_list: + + # Convert the array indices to x and y values + intervalx = intervals[0] + intervaly = intervals[1] + func_eval = func(intervalx, intervaly) + # The expression is valid in the interval. Change the contour + # array values to 1. + if func_eval[1] is False or func_eval[0] is False: + pass + elif func_eval == (True, True): + plot_list.append([intervalx, intervaly]) + elif func_eval[1] is None or func_eval[0] is None: + # Subdivide + avgx = intervalx.mid + avgy = intervaly.mid + a = interval(intervalx.start, avgx) + b = interval(avgx, intervalx.end) + c = interval(intervaly.start, avgy) + d = interval(avgy, intervaly.end) + temp_interval_list.append([a, c]) + temp_interval_list.append([a, d]) + temp_interval_list.append([b, c]) + temp_interval_list.append([b, d]) + return temp_interval_list, plot_list + + while k >= 0 and len(interval_list): + interval_list, plot_list_temp = refine_pixels(interval_list) + plot_list.extend(plot_list_temp) + k = k - 1 + # Check whether the expression represents an equality + # If it represents an equality, then none of the intervals + # would have satisfied the expression due to floating point + # differences. Add all the undecided values to the plot. + if self.has_equality: + for intervals in interval_list: + intervalx = intervals[0] + intervaly = intervals[1] + func_eval = func(intervalx, intervaly) + if func_eval[1] and func_eval[0] is not False: + plot_list.append([intervalx, intervaly]) + return plot_list, "fill" + + def _get_meshes_grid(self): + """Generates the mesh for generating a contour. + + In the case of equality, ``contour`` function of matplotlib can + be used. In other cases, matplotlib's ``contourf`` is used. + """ + np = import_module('numpy') + + xarray, yarray, z_grid = self._evaluate() + _re, _im = np.real(z_grid), np.imag(z_grid) + _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan + if self._is_equality: + return xarray, yarray, _re, 'contour' + return xarray, yarray, _re, 'contourf' + + @staticmethod + def _preprocess_meshgrid_expression(expr, adaptive): + """If the expression is a Relational, rewrite it as a single + expression. + + Returns + ======= + + expr : Expr + The rewritten expression + + equality : Boolean + Whether the original expression was an Equality or not. + """ + equality = False + if isinstance(expr, Equality): + expr = expr.lhs - expr.rhs + equality = True + elif isinstance(expr, Relational): + expr = expr.gts - expr.lts + elif not adaptive: + raise NotImplementedError( + "The expression is not supported for " + "plotting in uniform meshed plot." + ) + return expr, equality + + def get_label(self, use_latex=False, wrapper="$%s$"): + """Return the label to be used to display the expression. + + Parameters + ========== + use_latex : bool + If False, the string representation of the expression is returned. + If True, the latex representation is returned. + wrapper : str + The backend might need the latex representation to be wrapped by + some characters. Default to ``"$%s$"``. + + Returns + ======= + label : str + """ + if use_latex is False: + return self._label + if self._label == str(self._adaptive_expr): + return self._get_wrapped_label(self._latex_label, wrapper) + return self._latex_label + + +############################################################################## +# Finding the centers of line segments or mesh faces +############################################################################## + +def centers_of_segments(array): + np = import_module('numpy') + return np.mean(np.vstack((array[:-1], array[1:])), 0) + + +def centers_of_faces(array): + np = import_module('numpy') + return np.mean(np.dstack((array[:-1, :-1], + array[1:, :-1], + array[:-1, 1:], + array[:-1, :-1], + )), 2) + + +def flat(x, y, z, eps=1e-3): + """Checks whether three points are almost collinear""" + np = import_module('numpy') + # Workaround plotting piecewise (#8577) + vector_a = (x - y).astype(float) + vector_b = (z - y).astype(float) + dot_product = np.dot(vector_a, vector_b) + vector_a_norm = np.linalg.norm(vector_a) + vector_b_norm = np.linalg.norm(vector_b) + cos_theta = dot_product / (vector_a_norm * vector_b_norm) + return abs(cos_theta + 1) < eps + + +def _set_discretization_points(kwargs, pt): + """Allow the use of the keyword arguments ``n, n1, n2`` to + specify the number of discretization points in one and two + directions, while keeping back-compatibility with older keyword arguments + like, ``nb_of_points, nb_of_points_*, points``. + + Parameters + ========== + + kwargs : dict + Dictionary of keyword arguments passed into a plotting function. + pt : type + The type of the series, which indicates the kind of plot we are + trying to create. + """ + replace_old_keywords = { + "nb_of_points": "n", + "nb_of_points_x": "n1", + "nb_of_points_y": "n2", + "nb_of_points_u": "n1", + "nb_of_points_v": "n2", + "points": "n" + } + for k, v in replace_old_keywords.items(): + if k in kwargs.keys(): + kwargs[v] = kwargs.pop(k) + + if pt in [LineOver1DRangeSeries, Parametric2DLineSeries, + Parametric3DLineSeries]: + if "n" in kwargs.keys(): + kwargs["n1"] = kwargs["n"] + if hasattr(kwargs["n"], "__iter__") and (len(kwargs["n"]) > 0): + kwargs["n1"] = kwargs["n"][0] + elif pt in [SurfaceOver2DRangeSeries, ContourSeries, + ParametricSurfaceSeries, ImplicitSeries]: + if "n" in kwargs.keys(): + if hasattr(kwargs["n"], "__iter__") and (len(kwargs["n"]) > 1): + kwargs["n1"] = kwargs["n"][0] + kwargs["n2"] = kwargs["n"][1] + else: + kwargs["n1"] = kwargs["n2"] = kwargs["n"] + return kwargs diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/textplot.py b/.venv/lib/python3.13/site-packages/sympy/plotting/textplot.py new file mode 100644 index 0000000000000000000000000000000000000000..5f1f2b639d6c387a6a36cf89fe36bc7717c92b2b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/textplot.py @@ -0,0 +1,168 @@ +from sympy.core.numbers import Float +from sympy.core.symbol import Dummy +from sympy.utilities.lambdify import lambdify + +import math + + +def is_valid(x): + """Check if a floating point number is valid""" + if x is None: + return False + if isinstance(x, complex): + return False + return not math.isinf(x) and not math.isnan(x) + + +def rescale(y, W, H, mi, ma): + """Rescale the given array `y` to fit into the integer values + between `0` and `H-1` for the values between ``mi`` and ``ma``. + """ + y_new = [] + + norm = ma - mi + offset = (ma + mi) / 2 + + for x in range(W): + if is_valid(y[x]): + normalized = (y[x] - offset) / norm + if not is_valid(normalized): + y_new.append(None) + else: + rescaled = Float((normalized*H + H/2) * (H-1)/H).round() + rescaled = int(rescaled) + y_new.append(rescaled) + else: + y_new.append(None) + return y_new + + +def linspace(start, stop, num): + return [start + (stop - start) * x / (num-1) for x in range(num)] + + +def textplot_str(expr, a, b, W=55, H=21): + """Generator for the lines of the plot""" + free = expr.free_symbols + if len(free) > 1: + raise ValueError( + "The expression must have a single variable. (Got {})" + .format(free)) + x = free.pop() if free else Dummy() + f = lambdify([x], expr) + if isinstance(a, complex): + if a.imag == 0: + a = a.real + if isinstance(b, complex): + if b.imag == 0: + b = b.real + a = float(a) + b = float(b) + + # Calculate function values + x = linspace(a, b, W) + y = [] + for val in x: + try: + y.append(f(val)) + # Not sure what exceptions to catch here or why... + except (ValueError, TypeError, ZeroDivisionError): + y.append(None) + + # Normalize height to screen space + y_valid = list(filter(is_valid, y)) + if y_valid: + ma = max(y_valid) + mi = min(y_valid) + if ma == mi: + if ma: + mi, ma = sorted([0, 2*ma]) + else: + mi, ma = -1, 1 + else: + mi, ma = -1, 1 + y_range = ma - mi + precision = math.floor(math.log10(y_range)) - 1 + precision *= -1 + mi = round(mi, precision) + ma = round(ma, precision) + y = rescale(y, W, H, mi, ma) + + y_bins = linspace(mi, ma, H) + + # Draw plot + margin = 7 + for h in range(H - 1, -1, -1): + s = [' '] * W + for i in range(W): + if y[i] == h: + if (i == 0 or y[i - 1] == h - 1) and (i == W - 1 or y[i + 1] == h + 1): + s[i] = '/' + elif (i == 0 or y[i - 1] == h + 1) and (i == W - 1 or y[i + 1] == h - 1): + s[i] = '\\' + else: + s[i] = '.' + + if h == 0: + for i in range(W): + s[i] = '_' + + # Print y values + if h in (0, H//2, H - 1): + prefix = ("%g" % y_bins[h]).rjust(margin)[:margin] + else: + prefix = " "*margin + s = "".join(s) + if h == H//2: + s = s.replace(" ", "-") + yield prefix + " |" + s + + # Print x values + bottom = " " * (margin + 2) + bottom += ("%g" % x[0]).ljust(W//2) + if W % 2 == 1: + bottom += ("%g" % x[W//2]).ljust(W//2) + else: + bottom += ("%g" % x[W//2]).ljust(W//2-1) + bottom += "%g" % x[-1] + yield bottom + + +def textplot(expr, a, b, W=55, H=21): + r""" + Print a crude ASCII art plot of the SymPy expression 'expr' (which + should contain a single symbol, e.g. x or something else) over the + interval [a, b]. + + Examples + ======== + + >>> from sympy import Symbol, sin + >>> from sympy.plotting import textplot + >>> t = Symbol('t') + >>> textplot(sin(t)*t, 0, 15) + 14 | ... + | . + | . + | . + | . + | ... + | / . . + | / + | / . + | . . . + 1.5 |----.......-------------------------------------------- + |.... \ . . + | \ / . + | .. / . + | \ / . + | .... + | . + | . . + | + | . . + -11 |_______________________________________________________ + 0 7.5 15 + """ + for line in textplot_str(expr, a, b, W, H): + print(line) diff --git a/.venv/lib/python3.13/site-packages/sympy/plotting/utils.py b/.venv/lib/python3.13/site-packages/sympy/plotting/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3213dea09b5a98e96094e7dffbd9b992c7d2b87e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/plotting/utils.py @@ -0,0 +1,323 @@ +from sympy.core.containers import Tuple +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.function import AppliedUndef +from sympy.core.relational import Relational +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.logic.boolalg import BooleanFunction +from sympy.sets.fancysets import ImageSet +from sympy.sets.sets import FiniteSet +from sympy.tensor.indexed import Indexed + + +def _get_free_symbols(exprs): + """Returns the free symbols of a symbolic expression. + + If the expression contains any of these elements, assume that they are + the "free symbols" of the expression: + + * indexed objects + * applied undefined function (useful for sympy.physics.mechanics module) + """ + if not isinstance(exprs, (list, tuple, set)): + exprs = [exprs] + if all(callable(e) for e in exprs): + return set() + + free = set().union(*[e.atoms(Indexed) for e in exprs]) + free = free.union(*[e.atoms(AppliedUndef) for e in exprs]) + return free or set().union(*[e.free_symbols for e in exprs]) + + +def extract_solution(set_sol, n=10): + """Extract numerical solutions from a set solution (computed by solveset, + linsolve, nonlinsolve). Often, it is not trivial do get something useful + out of them. + + Parameters + ========== + + n : int, optional + In order to replace ImageSet with FiniteSet, an iterator is created + for each ImageSet contained in `set_sol`, starting from 0 up to `n`. + Default value: 10. + """ + images = set_sol.find(ImageSet) + for im in images: + it = iter(im) + s = FiniteSet(*[next(it) for n in range(0, n)]) + set_sol = set_sol.subs(im, s) + return set_sol + + +def _plot_sympify(args): + """This function recursively loop over the arguments passed to the plot + functions: the sympify function will be applied to all arguments except + those of type string/dict. + + Generally, users can provide the following arguments to a plot function: + + expr, range1 [tuple, opt], ..., label [str, opt], rendering_kw [dict, opt] + + `expr, range1, ...` can be sympified, whereas `label, rendering_kw` can't. + In particular, whenever a special character like $, {, }, ... is used in + the `label`, sympify will raise an error. + """ + if isinstance(args, Expr): + return args + + args = list(args) + for i, a in enumerate(args): + if isinstance(a, (list, tuple)): + args[i] = Tuple(*_plot_sympify(a), sympify=False) + elif not (isinstance(a, (str, dict)) or callable(a) + # NOTE: check if it is a vector from sympy.physics.vector module + # without importing the module (because it slows down SymPy's + # import process and triggers SymPy's optional-dependencies + # tests to fail). + or ((a.__class__.__name__ == "Vector") and not isinstance(a, Basic)) + ): + args[i] = sympify(a) + return args + + +def _create_ranges(exprs, ranges, npar, label="", params=None): + """This function does two things: + + 1. Check if the number of free symbols is in agreement with the type of + plot chosen. For example, plot() requires 1 free symbol; + plot3d() requires 2 free symbols. + 2. Sometime users create plots without providing ranges for the variables. + Here we create the necessary ranges. + + Parameters + ========== + + exprs : iterable + The expressions from which to extract the free symbols + ranges : iterable + The limiting ranges provided by the user + npar : int + The number of free symbols required by the plot functions. + For example, + npar=1 for plot, npar=2 for plot3d, ... + params : dict + A dictionary mapping symbols to parameters for interactive plot. + """ + get_default_range = lambda symbol: Tuple(symbol, -10, 10) + + free_symbols = _get_free_symbols(exprs) + if params is not None: + free_symbols = free_symbols.difference(params.keys()) + + if len(free_symbols) > npar: + raise ValueError( + "Too many free symbols.\n" + + "Expected {} free symbols.\n".format(npar) + + "Received {}: {}".format(len(free_symbols), free_symbols) + ) + + if len(ranges) > npar: + raise ValueError( + "Too many ranges. Received %s, expected %s" % (len(ranges), npar)) + + # free symbols in the ranges provided by the user + rfs = set().union([r[0] for r in ranges]) + if len(rfs) != len(ranges): + raise ValueError("Multiple ranges with the same symbol") + + if len(ranges) < npar: + symbols = free_symbols.difference(rfs) + if symbols != set(): + # add a range for each missing free symbols + for s in symbols: + ranges.append(get_default_range(s)) + # if there is still room, fill them with dummys + for i in range(npar - len(ranges)): + ranges.append(get_default_range(Dummy())) + + if len(free_symbols) == npar: + # there could be times when this condition is not met, for example + # plotting the function f(x, y) = x (which is a plane); in this case, + # free_symbols = {x} whereas rfs = {x, y} (or x and Dummy) + rfs = set().union([r[0] for r in ranges]) + if len(free_symbols.difference(rfs)) > 0: + raise ValueError( + "Incompatible free symbols of the expressions with " + "the ranges.\n" + + "Free symbols in the expressions: {}\n".format(free_symbols) + + "Free symbols in the ranges: {}".format(rfs) + ) + return ranges + + +def _is_range(r): + """A range is defined as (symbol, start, end). start and end should + be numbers. + """ + # TODO: prange check goes here + return ( + isinstance(r, Tuple) + and (len(r) == 3) + and (not isinstance(r.args[1], str)) and r.args[1].is_number + and (not isinstance(r.args[2], str)) and r.args[2].is_number + ) + + +def _unpack_args(*args): + """Given a list/tuple of arguments previously processed by _plot_sympify() + and/or _check_arguments(), separates and returns its components: + expressions, ranges, label and rendering keywords. + + Examples + ======== + + >>> from sympy import cos, sin, symbols + >>> from sympy.plotting.utils import _plot_sympify, _unpack_args + >>> x, y = symbols('x, y') + >>> args = (sin(x), (x, -10, 10), "f1") + >>> args = _plot_sympify(args) + >>> _unpack_args(*args) + ([sin(x)], [(x, -10, 10)], 'f1', None) + + >>> args = (sin(x**2 + y**2), (x, -2, 2), (y, -3, 3), "f2") + >>> args = _plot_sympify(args) + >>> _unpack_args(*args) + ([sin(x**2 + y**2)], [(x, -2, 2), (y, -3, 3)], 'f2', None) + + >>> args = (sin(x + y), cos(x - y), x + y, (x, -2, 2), (y, -3, 3), "f3") + >>> args = _plot_sympify(args) + >>> _unpack_args(*args) + ([sin(x + y), cos(x - y), x + y], [(x, -2, 2), (y, -3, 3)], 'f3', None) + """ + ranges = [t for t in args if _is_range(t)] + labels = [t for t in args if isinstance(t, str)] + label = None if not labels else labels[0] + rendering_kw = [t for t in args if isinstance(t, dict)] + rendering_kw = None if not rendering_kw else rendering_kw[0] + # NOTE: why None? because args might have been preprocessed by + # _check_arguments, so None might represent the rendering_kw + results = [not (_is_range(a) or isinstance(a, (str, dict)) or (a is None)) for a in args] + exprs = [a for a, b in zip(args, results) if b] + return exprs, ranges, label, rendering_kw + + +def _check_arguments(args, nexpr, npar, **kwargs): + """Checks the arguments and converts into tuples of the + form (exprs, ranges, label, rendering_kw). + + Parameters + ========== + + args + The arguments provided to the plot functions + nexpr + The number of sub-expression forming an expression to be plotted. + For example: + nexpr=1 for plot. + nexpr=2 for plot_parametric: a curve is represented by a tuple of two + elements. + nexpr=1 for plot3d. + nexpr=3 for plot3d_parametric_line: a curve is represented by a tuple + of three elements. + npar + The number of free symbols required by the plot functions. For example, + npar=1 for plot, npar=2 for plot3d, ... + **kwargs : + keyword arguments passed to the plotting function. It will be used to + verify if ``params`` has ben provided. + + Examples + ======== + + .. plot:: + :context: reset + :format: doctest + :include-source: True + + >>> from sympy import cos, sin, symbols + >>> from sympy.plotting.plot import _check_arguments + >>> x = symbols('x') + >>> _check_arguments([cos(x), sin(x)], 2, 1) + [(cos(x), sin(x), (x, -10, 10), None, None)] + + >>> _check_arguments([cos(x), sin(x), "test"], 2, 1) + [(cos(x), sin(x), (x, -10, 10), 'test', None)] + + >>> _check_arguments([cos(x), sin(x), "test", {"a": 0, "b": 1}], 2, 1) + [(cos(x), sin(x), (x, -10, 10), 'test', {'a': 0, 'b': 1})] + + >>> _check_arguments([x, x**2], 1, 1) + [(x, (x, -10, 10), None, None), (x**2, (x, -10, 10), None, None)] + """ + if not args: + return [] + output = [] + params = kwargs.get("params", None) + + if all(isinstance(a, (Expr, Relational, BooleanFunction)) for a in args[:nexpr]): + # In this case, with a single plot command, we are plotting either: + # 1. one expression + # 2. multiple expressions over the same range + + exprs, ranges, label, rendering_kw = _unpack_args(*args) + free_symbols = set().union(*[e.free_symbols for e in exprs]) + ranges = _create_ranges(exprs, ranges, npar, label, params) + + if nexpr > 1: + # in case of plot_parametric or plot3d_parametric_line, there will + # be 2 or 3 expressions defining a curve. Group them together. + if len(exprs) == nexpr: + exprs = (tuple(exprs),) + for expr in exprs: + # need this if-else to deal with both plot/plot3d and + # plot_parametric/plot3d_parametric_line + is_expr = isinstance(expr, (Expr, Relational, BooleanFunction)) + e = (expr,) if is_expr else expr + output.append((*e, *ranges, label, rendering_kw)) + + else: + # In this case, we are plotting multiple expressions, each one with its + # range. Each "expression" to be plotted has the following form: + # (expr, range, label) where label is optional + + _, ranges, labels, rendering_kw = _unpack_args(*args) + labels = [labels] if labels else [] + + # number of expressions + n = (len(ranges) + len(labels) + + (len(rendering_kw) if rendering_kw is not None else 0)) + new_args = args[:-n] if n > 0 else args + + # at this point, new_args might just be [expr]. But I need it to be + # [[expr]] in order to be able to loop over + # [expr, range [opt], label [opt]] + if not isinstance(new_args[0], (list, tuple, Tuple)): + new_args = [new_args] + + # Each arg has the form (expr1, expr2, ..., range1 [optional], ..., + # label [optional], rendering_kw [optional]) + for arg in new_args: + # look for "local" range and label. If there is not, use "global". + l = [a for a in arg if isinstance(a, str)] + if not l: + l = labels + r = [a for a in arg if _is_range(a)] + if not r: + r = ranges.copy() + rend_kw = [a for a in arg if isinstance(a, dict)] + rend_kw = rendering_kw if len(rend_kw) == 0 else rend_kw[0] + + # NOTE: arg = arg[:nexpr] may raise an exception if lambda + # functions are used. Execute the following instead: + arg = [arg[i] for i in range(nexpr)] + free_symbols = set() + if all(not callable(a) for a in arg): + free_symbols = free_symbols.union(*[a.free_symbols for a in arg]) + if len(r) != npar: + r = _create_ranges(arg, r, npar, "", params) + + label = None if not l else l[0] + output.append((*arg, *r, label, rend_kw)) + return output diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/c.py b/.venv/lib/python3.13/site-packages/sympy/printing/c.py new file mode 100644 index 0000000000000000000000000000000000000000..34c4b8f021073aeee7672248838557b5fa85fbae --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/c.py @@ -0,0 +1,747 @@ +""" +C code printer + +The C89CodePrinter & C99CodePrinter converts single SymPy expressions into +single C expressions, using the functions defined in math.h where possible. + +A complete code generator, which uses ccode extensively, can be found in +sympy.utilities.codegen. The codegen module can be used to generate complete +source code files that are compilable without further modifications. + + +""" + +from __future__ import annotations +from typing import Any + +from functools import wraps +from itertools import chain + +from sympy.core import S +from sympy.core.numbers import equal_valued, Float +from sympy.codegen.ast import ( + Assignment, Pointer, Variable, Declaration, Type, + real, complex_, integer, bool_, float32, float64, float80, + complex64, complex128, intc, value_const, pointer_const, + int8, int16, int32, int64, uint8, uint16, uint32, uint64, untyped, + none +) +from sympy.printing.codeprinter import CodePrinter, requires +from sympy.printing.precedence import precedence, PRECEDENCE +from sympy.sets.fancysets import Range + +# These are defined in the other file so we can avoid importing sympy.codegen +# from the top-level 'import sympy'. Export them here as well. +from sympy.printing.codeprinter import ccode, print_ccode # noqa:F401 + +# dictionary mapping SymPy function to (argument_conditions, C_function). +# Used in C89CodePrinter._print_Function(self) +known_functions_C89 = { + "Abs": [(lambda x: not x.is_integer, "fabs"), (lambda x: x.is_integer, "abs")], + "sin": "sin", + "cos": "cos", + "tan": "tan", + "asin": "asin", + "acos": "acos", + "atan": "atan", + "atan2": "atan2", + "exp": "exp", + "log": "log", + "log10": "log10", + "sinh": "sinh", + "cosh": "cosh", + "tanh": "tanh", + "floor": "floor", + "ceiling": "ceil", + "sqrt": "sqrt", # To enable automatic rewrites +} + +known_functions_C99 = dict(known_functions_C89, **{ + 'exp2': 'exp2', + 'expm1': 'expm1', + 'log2': 'log2', + 'log1p': 'log1p', + 'Cbrt': 'cbrt', + 'hypot': 'hypot', + 'fma': 'fma', + 'loggamma': 'lgamma', + 'erfc': 'erfc', + 'Max': 'fmax', + 'Min': 'fmin', + "asinh": "asinh", + "acosh": "acosh", + "atanh": "atanh", + "erf": "erf", + "gamma": "tgamma", +}) + +# These are the core reserved words in the C language. Taken from: +# https://en.cppreference.com/w/c/keyword + +reserved_words = [ + 'auto', 'break', 'case', 'char', 'const', 'continue', 'default', 'do', + 'double', 'else', 'enum', 'extern', 'float', 'for', 'goto', 'if', 'int', + 'long', 'register', 'return', 'short', 'signed', 'sizeof', 'static', + 'struct', 'entry', # never standardized, we'll leave it here anyway + 'switch', 'typedef', 'union', 'unsigned', 'void', 'volatile', 'while' +] + +reserved_words_c99 = ['inline', 'restrict'] + +def get_math_macros(): + """ Returns a dictionary with math-related macros from math.h/cmath + + Note that these macros are not strictly required by the C/C++-standard. + For MSVC they are enabled by defining "_USE_MATH_DEFINES" (preferably + via a compilation flag). + + Returns + ======= + + Dictionary mapping SymPy expressions to strings (macro names) + + """ + from sympy.codegen.cfunctions import log2, Sqrt + from sympy.functions.elementary.exponential import log + from sympy.functions.elementary.miscellaneous import sqrt + + return { + S.Exp1: 'M_E', + log2(S.Exp1): 'M_LOG2E', + 1/log(2): 'M_LOG2E', + log(2): 'M_LN2', + log(10): 'M_LN10', + S.Pi: 'M_PI', + S.Pi/2: 'M_PI_2', + S.Pi/4: 'M_PI_4', + 1/S.Pi: 'M_1_PI', + 2/S.Pi: 'M_2_PI', + 2/sqrt(S.Pi): 'M_2_SQRTPI', + 2/Sqrt(S.Pi): 'M_2_SQRTPI', + sqrt(2): 'M_SQRT2', + Sqrt(2): 'M_SQRT2', + 1/sqrt(2): 'M_SQRT1_2', + 1/Sqrt(2): 'M_SQRT1_2' + } + + +def _as_macro_if_defined(meth): + """ Decorator for printer methods + + When a Printer's method is decorated using this decorator the expressions printed + will first be looked for in the attribute ``math_macros``, and if present it will + print the macro name in ``math_macros`` followed by a type suffix for the type + ``real``. e.g. printing ``sympy.pi`` would print ``M_PIl`` if real is mapped to float80. + + """ + @wraps(meth) + def _meth_wrapper(self, expr, **kwargs): + if expr in self.math_macros: + return '%s%s' % (self.math_macros[expr], self._get_math_macro_suffix(real)) + else: + return meth(self, expr, **kwargs) + + return _meth_wrapper + + +class C89CodePrinter(CodePrinter): + """A printer to convert Python expressions to strings of C code""" + printmethod = "_ccode" + language = "C" + standard = "C89" + reserved_words = set(reserved_words) + + _default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{ + 'precision': 17, + 'user_functions': {}, + 'contract': True, + 'dereference': set(), + 'error_on_reserved': False, + }) + + type_aliases = { + real: float64, + complex_: complex128, + integer: intc + } + + type_mappings: dict[Type, Any] = { + real: 'double', + intc: 'int', + float32: 'float', + float64: 'double', + integer: 'int', + bool_: 'bool', + int8: 'int8_t', + int16: 'int16_t', + int32: 'int32_t', + int64: 'int64_t', + uint8: 'int8_t', + uint16: 'int16_t', + uint32: 'int32_t', + uint64: 'int64_t', + } + + type_headers = { + bool_: {'stdbool.h'}, + int8: {'stdint.h'}, + int16: {'stdint.h'}, + int32: {'stdint.h'}, + int64: {'stdint.h'}, + uint8: {'stdint.h'}, + uint16: {'stdint.h'}, + uint32: {'stdint.h'}, + uint64: {'stdint.h'}, + } + + # Macros needed to be defined when using a Type + type_macros: dict[Type, tuple[str, ...]] = {} + + type_func_suffixes = { + float32: 'f', + float64: '', + float80: 'l' + } + + type_literal_suffixes = { + float32: 'F', + float64: '', + float80: 'L' + } + + type_math_macro_suffixes = { + float80: 'l' + } + + math_macros = None + + _ns = '' # namespace, C++ uses 'std::' + # known_functions-dict to copy + _kf: dict[str, Any] = known_functions_C89 + + def __init__(self, settings=None): + settings = settings or {} + if self.math_macros is None: + self.math_macros = settings.pop('math_macros', get_math_macros()) + self.type_aliases = dict(chain(self.type_aliases.items(), + settings.pop('type_aliases', {}).items())) + self.type_mappings = dict(chain(self.type_mappings.items(), + settings.pop('type_mappings', {}).items())) + self.type_headers = dict(chain(self.type_headers.items(), + settings.pop('type_headers', {}).items())) + self.type_macros = dict(chain(self.type_macros.items(), + settings.pop('type_macros', {}).items())) + self.type_func_suffixes = dict(chain(self.type_func_suffixes.items(), + settings.pop('type_func_suffixes', {}).items())) + self.type_literal_suffixes = dict(chain(self.type_literal_suffixes.items(), + settings.pop('type_literal_suffixes', {}).items())) + self.type_math_macro_suffixes = dict(chain(self.type_math_macro_suffixes.items(), + settings.pop('type_math_macro_suffixes', {}).items())) + super().__init__(settings) + self.known_functions = dict(self._kf, **settings.get('user_functions', {})) + self._dereference = set(settings.get('dereference', [])) + self.headers = set() + self.libraries = set() + self.macros = set() + + def _rate_index_position(self, p): + return p*5 + + def _get_statement(self, codestring): + """ Get code string as a statement - i.e. ending with a semicolon. """ + return codestring if codestring.endswith(';') else codestring + ';' + + def _get_comment(self, text): + return "/* {} */".format(text) + + def _declare_number_const(self, name, value): + type_ = self.type_aliases[real] + var = Variable(name, type=type_, value=value.evalf(type_.decimal_dig), attrs={value_const}) + decl = Declaration(var) + return self._get_statement(self._print(decl)) + + def _format_code(self, lines): + return self.indent_code(lines) + + def _traverse_matrix_indices(self, mat): + rows, cols = mat.shape + return ((i, j) for i in range(rows) for j in range(cols)) + + @_as_macro_if_defined + def _print_Mul(self, expr, **kwargs): + return super()._print_Mul(expr, **kwargs) + + @_as_macro_if_defined + def _print_Pow(self, expr): + if "Pow" in self.known_functions: + return self._print_Function(expr) + PREC = precedence(expr) + suffix = self._get_func_suffix(real) + if equal_valued(expr.exp, -1): + return '%s/%s' % (self._print_Float(Float(1.0)), self.parenthesize(expr.base, PREC)) + elif equal_valued(expr.exp, 0.5): + return '%ssqrt%s(%s)' % (self._ns, suffix, self._print(expr.base)) + elif expr.exp == S.One/3 and self.standard != 'C89': + return '%scbrt%s(%s)' % (self._ns, suffix, self._print(expr.base)) + else: + return '%spow%s(%s, %s)' % (self._ns, suffix, self._print(expr.base), + self._print(expr.exp)) + + def _print_Mod(self, expr): + num, den = expr.args + if num.is_integer and den.is_integer: + PREC = precedence(expr) + snum, sden = [self.parenthesize(arg, PREC) for arg in expr.args] + # % is remainder (same sign as numerator), not modulo (same sign as + # denominator), in C. Hence, % only works as modulo if both numbers + # have the same sign + if (num.is_nonnegative and den.is_nonnegative or + num.is_nonpositive and den.is_nonpositive): + return f"{snum} % {sden}" + return f"(({snum} % {sden}) + {sden}) % {sden}" + # Not guaranteed integer + return self._print_math_func(expr, known='fmod') + + def _print_Rational(self, expr): + p, q = int(expr.p), int(expr.q) + suffix = self._get_literal_suffix(real) + return '%d.0%s/%d.0%s' % (p, suffix, q, suffix) + + def _print_Indexed(self, expr): + # calculate index for 1d array + offset = getattr(expr.base, 'offset', S.Zero) + strides = getattr(expr.base, 'strides', None) + indices = expr.indices + + if strides is None or isinstance(strides, str): + dims = expr.shape + shift = S.One + temp = () + if strides == 'C' or strides is None: + traversal = reversed(range(expr.rank)) + indices = indices[::-1] + elif strides == 'F': + traversal = range(expr.rank) + + for i in traversal: + temp += (shift,) + shift *= dims[i] + strides = temp + flat_index = sum(x[0]*x[1] for x in zip(indices, strides)) + offset + return "%s[%s]" % (self._print(expr.base.label), + self._print(flat_index)) + + @_as_macro_if_defined + def _print_NumberSymbol(self, expr): + return super()._print_NumberSymbol(expr) + + def _print_Infinity(self, expr): + return 'HUGE_VAL' + + def _print_NegativeInfinity(self, expr): + return '-HUGE_VAL' + + def _print_Piecewise(self, expr): + if expr.args[-1].cond != True: + # We need the last conditional to be a True, otherwise the resulting + # function may not return a result. + raise ValueError("All Piecewise expressions must contain an " + "(expr, True) statement to be used as a default " + "condition. Without one, the generated " + "expression may not evaluate to anything under " + "some condition.") + lines = [] + if expr.has(Assignment): + for i, (e, c) in enumerate(expr.args): + if i == 0: + lines.append("if (%s) {" % self._print(c)) + elif i == len(expr.args) - 1 and c == True: + lines.append("else {") + else: + lines.append("else if (%s) {" % self._print(c)) + code0 = self._print(e) + lines.append(code0) + lines.append("}") + return "\n".join(lines) + else: + # The piecewise was used in an expression, need to do inline + # operators. This has the downside that inline operators will + # not work for statements that span multiple lines (Matrix or + # Indexed expressions). + ecpairs = ["((%s) ? (\n%s\n)\n" % (self._print(c), + self._print(e)) + for e, c in expr.args[:-1]] + last_line = ": (\n%s\n)" % self._print(expr.args[-1].expr) + return ": ".join(ecpairs) + last_line + " ".join([")"*len(ecpairs)]) + + def _print_ITE(self, expr): + from sympy.functions import Piecewise + return self._print(expr.rewrite(Piecewise, deep=False)) + + def _print_MatrixElement(self, expr): + return "{}[{}]".format(self.parenthesize(expr.parent, PRECEDENCE["Atom"], + strict=True), expr.j + expr.i*expr.parent.shape[1]) + + def _print_Symbol(self, expr): + name = super()._print_Symbol(expr) + if expr in self._settings['dereference']: + return '(*{})'.format(name) + else: + return name + + def _print_Relational(self, expr): + lhs_code = self._print(expr.lhs) + rhs_code = self._print(expr.rhs) + op = expr.rel_op + return "{} {} {}".format(lhs_code, op, rhs_code) + + def _print_For(self, expr): + target = self._print(expr.target) + if isinstance(expr.iterable, Range): + start, stop, step = expr.iterable.args + else: + raise NotImplementedError("Only iterable currently supported is Range") + body = self._print(expr.body) + return ('for ({target} = {start}; {target} < {stop}; {target} += ' + '{step}) {{\n{body}\n}}').format(target=target, start=start, + stop=stop, step=step, body=body) + + def _print_sign(self, func): + return '((({0}) > 0) - (({0}) < 0))'.format(self._print(func.args[0])) + + def _print_Max(self, expr): + if "Max" in self.known_functions: + return self._print_Function(expr) + def inner_print_max(args): # The more natural abstraction of creating + if len(args) == 1: # and printing smaller Max objects is slow + return self._print(args[0]) # when there are many arguments. + half = len(args) // 2 + return "((%(a)s > %(b)s) ? %(a)s : %(b)s)" % { + 'a': inner_print_max(args[:half]), + 'b': inner_print_max(args[half:]) + } + return inner_print_max(expr.args) + + def _print_Min(self, expr): + if "Min" in self.known_functions: + return self._print_Function(expr) + def inner_print_min(args): # The more natural abstraction of creating + if len(args) == 1: # and printing smaller Min objects is slow + return self._print(args[0]) # when there are many arguments. + half = len(args) // 2 + return "((%(a)s < %(b)s) ? %(a)s : %(b)s)" % { + 'a': inner_print_min(args[:half]), + 'b': inner_print_min(args[half:]) + } + return inner_print_min(expr.args) + + def indent_code(self, code): + """Accepts a string of code or a list of code lines""" + + if isinstance(code, str): + code_lines = self.indent_code(code.splitlines(True)) + return ''.join(code_lines) + + tab = " " + inc_token = ('{', '(', '{\n', '(\n') + dec_token = ('}', ')') + + code = [line.lstrip(' \t') for line in code] + + increase = [int(any(map(line.endswith, inc_token))) for line in code] + decrease = [int(any(map(line.startswith, dec_token))) for line in code] + + pretty = [] + level = 0 + for n, line in enumerate(code): + if line in ('', '\n'): + pretty.append(line) + continue + level -= decrease[n] + pretty.append("%s%s" % (tab*level, line)) + level += increase[n] + return pretty + + def _get_func_suffix(self, type_): + return self.type_func_suffixes[self.type_aliases.get(type_, type_)] + + def _get_literal_suffix(self, type_): + return self.type_literal_suffixes[self.type_aliases.get(type_, type_)] + + def _get_math_macro_suffix(self, type_): + alias = self.type_aliases.get(type_, type_) + dflt = self.type_math_macro_suffixes.get(alias, '') + return self.type_math_macro_suffixes.get(type_, dflt) + + def _print_Tuple(self, expr): + return '{'+', '.join(self._print(e) for e in expr)+'}' + + _print_List = _print_Tuple + + def _print_Type(self, type_): + self.headers.update(self.type_headers.get(type_, set())) + self.macros.update(self.type_macros.get(type_, set())) + return self._print(self.type_mappings.get(type_, type_.name)) + + def _print_Declaration(self, decl): + from sympy.codegen.cnodes import restrict + var = decl.variable + val = var.value + if var.type == untyped: + raise ValueError("C does not support untyped variables") + + if isinstance(var, Pointer): + result = '{vc}{t} *{pc} {r}{s}'.format( + vc='const ' if value_const in var.attrs else '', + t=self._print(var.type), + pc=' const' if pointer_const in var.attrs else '', + r='restrict ' if restrict in var.attrs else '', + s=self._print(var.symbol) + ) + elif isinstance(var, Variable): + result = '{vc}{t} {s}'.format( + vc='const ' if value_const in var.attrs else '', + t=self._print(var.type), + s=self._print(var.symbol) + ) + else: + raise NotImplementedError("Unknown type of var: %s" % type(var)) + if val != None: # Must be "!= None", cannot be "is not None" + result += ' = %s' % self._print(val) + return result + + def _print_Float(self, flt): + type_ = self.type_aliases.get(real, real) + self.macros.update(self.type_macros.get(type_, set())) + suffix = self._get_literal_suffix(type_) + num = str(flt.evalf(type_.decimal_dig)) + if 'e' not in num and '.' not in num: + num += '.0' + num_parts = num.split('e') + num_parts[0] = num_parts[0].rstrip('0') + if num_parts[0].endswith('.'): + num_parts[0] += '0' + return 'e'.join(num_parts) + suffix + + @requires(headers={'stdbool.h'}) + def _print_BooleanTrue(self, expr): + return 'true' + + @requires(headers={'stdbool.h'}) + def _print_BooleanFalse(self, expr): + return 'false' + + def _print_Element(self, elem): + if elem.strides == None: # Must be "== None", cannot be "is None" + if elem.offset != None: # Must be "!= None", cannot be "is not None" + raise ValueError("Expected strides when offset is given") + idxs = ']['.join((self._print(arg) for arg in elem.indices)) + else: + global_idx = sum(i*s for i, s in zip(elem.indices, elem.strides)) + if elem.offset != None: # Must be "!= None", cannot be "is not None" + global_idx += elem.offset + idxs = self._print(global_idx) + + return "{symb}[{idxs}]".format( + symb=self._print(elem.symbol), + idxs=idxs + ) + + def _print_CodeBlock(self, expr): + """ Elements of code blocks printed as statements. """ + return '\n'.join([self._get_statement(self._print(i)) for i in expr.args]) + + def _print_While(self, expr): + return 'while ({condition}) {{\n{body}\n}}'.format(**expr.kwargs( + apply=lambda arg: self._print(arg))) + + def _print_Scope(self, expr): + return '{\n%s\n}' % self._print_CodeBlock(expr.body) + + @requires(headers={'stdio.h'}) + def _print_Print(self, expr): + if expr.file == none: + template = 'printf({fmt}, {pargs})' + else: + template = 'fprintf(%(out)s, {fmt}, {pargs})' % { + 'out': self._print(expr.file) + } + return template.format( + fmt="%s\n" if expr.format_string == none else self._print(expr.format_string), + pargs=', '.join((self._print(arg) for arg in expr.print_args)) + ) + + def _print_Stream(self, strm): + return strm.name + + def _print_FunctionPrototype(self, expr): + pars = ', '.join((self._print(Declaration(arg)) for arg in expr.parameters)) + return "%s %s(%s)" % ( + tuple((self._print(arg) for arg in (expr.return_type, expr.name))) + (pars,) + ) + + def _print_FunctionDefinition(self, expr): + return "%s%s" % (self._print_FunctionPrototype(expr), + self._print_Scope(expr)) + + def _print_Return(self, expr): + arg, = expr.args + return 'return %s' % self._print(arg) + + def _print_CommaOperator(self, expr): + return '(%s)' % ', '.join((self._print(arg) for arg in expr.args)) + + def _print_Label(self, expr): + if expr.body == none: + return '%s:' % str(expr.name) + if len(expr.body.args) == 1: + return '%s:\n%s' % (str(expr.name), self._print_CodeBlock(expr.body)) + return '%s:\n{\n%s\n}' % (str(expr.name), self._print_CodeBlock(expr.body)) + + def _print_goto(self, expr): + return 'goto %s' % expr.label.name + + def _print_PreIncrement(self, expr): + arg, = expr.args + return '++(%s)' % self._print(arg) + + def _print_PostIncrement(self, expr): + arg, = expr.args + return '(%s)++' % self._print(arg) + + def _print_PreDecrement(self, expr): + arg, = expr.args + return '--(%s)' % self._print(arg) + + def _print_PostDecrement(self, expr): + arg, = expr.args + return '(%s)--' % self._print(arg) + + def _print_struct(self, expr): + return "%(keyword)s %(name)s {\n%(lines)s}" % { + "keyword": expr.__class__.__name__, "name": expr.name, "lines": ';\n'.join( + [self._print(decl) for decl in expr.declarations] + ['']) + } + + def _print_BreakToken(self, _): + return 'break' + + def _print_ContinueToken(self, _): + return 'continue' + + _print_union = _print_struct + +class C99CodePrinter(C89CodePrinter): + standard = 'C99' + reserved_words = set(reserved_words + reserved_words_c99) + type_mappings=dict(chain(C89CodePrinter.type_mappings.items(), { + complex64: 'float complex', + complex128: 'double complex', + }.items())) + type_headers = dict(chain(C89CodePrinter.type_headers.items(), { + complex64: {'complex.h'}, + complex128: {'complex.h'} + }.items())) + + # known_functions-dict to copy + _kf: dict[str, Any] = known_functions_C99 + + # functions with versions with 'f' and 'l' suffixes: + _prec_funcs = ('fabs fmod remainder remquo fma fmax fmin fdim nan exp exp2' + ' expm1 log log10 log2 log1p pow sqrt cbrt hypot sin cos tan' + ' asin acos atan atan2 sinh cosh tanh asinh acosh atanh erf' + ' erfc tgamma lgamma ceil floor trunc round nearbyint rint' + ' frexp ldexp modf scalbn ilogb logb nextafter copysign').split() + + def _print_Infinity(self, expr): + return 'INFINITY' + + def _print_NegativeInfinity(self, expr): + return '-INFINITY' + + def _print_NaN(self, expr): + return 'NAN' + + # tgamma was already covered by 'known_functions' dict + + @requires(headers={'math.h'}, libraries={'m'}) + @_as_macro_if_defined + def _print_math_func(self, expr, nest=False, known=None): + if known is None: + known = self.known_functions[expr.__class__.__name__] + if not isinstance(known, str): + for cb, name in known: + if cb(*expr.args): + known = name + break + else: + raise ValueError("No matching printer") + try: + return known(self, *expr.args) + except TypeError: + suffix = self._get_func_suffix(real) if self._ns + known in self._prec_funcs else '' + + if nest: + args = self._print(expr.args[0]) + if len(expr.args) > 1: + paren_pile = '' + for curr_arg in expr.args[1:-1]: + paren_pile += ')' + args += ', {ns}{name}{suffix}({next}'.format( + ns=self._ns, + name=known, + suffix=suffix, + next = self._print(curr_arg) + ) + args += ', %s%s' % ( + self._print(expr.func(expr.args[-1])), + paren_pile + ) + else: + args = ', '.join((self._print(arg) for arg in expr.args)) + return '{ns}{name}{suffix}({args})'.format( + ns=self._ns, + name=known, + suffix=suffix, + args=args + ) + + def _print_Max(self, expr): + return self._print_math_func(expr, nest=True) + + def _print_Min(self, expr): + return self._print_math_func(expr, nest=True) + + def _get_loop_opening_ending(self, indices): + open_lines = [] + close_lines = [] + loopstart = "for (int %(var)s=%(start)s; %(var)s<%(end)s; %(var)s++){" # C99 + for i in indices: + # C arrays start at 0 and end at dimension-1 + open_lines.append(loopstart % { + 'var': self._print(i.label), + 'start': self._print(i.lower), + 'end': self._print(i.upper + 1)}) + close_lines.append("}") + return open_lines, close_lines + + +for k in ('Abs Sqrt exp exp2 expm1 log log10 log2 log1p Cbrt hypot fma' + ' loggamma sin cos tan asin acos atan atan2 sinh cosh tanh asinh acosh ' + 'atanh erf erfc loggamma gamma ceiling floor').split(): + setattr(C99CodePrinter, '_print_%s' % k, C99CodePrinter._print_math_func) + + +class C11CodePrinter(C99CodePrinter): + + @requires(headers={'stdalign.h'}) + def _print_alignof(self, expr): + arg, = expr.args + return 'alignof(%s)' % self._print(arg) + + +c_code_printers = { + 'c89': C89CodePrinter, + 'c99': C99CodePrinter, + 'c11': C11CodePrinter +} diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/conventions.py b/.venv/lib/python3.13/site-packages/sympy/printing/conventions.py new file mode 100644 index 0000000000000000000000000000000000000000..4f5545ae38511e0bb0366da5c9fb4e59156095c8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/conventions.py @@ -0,0 +1,88 @@ +""" +A few practical conventions common to all printers. +""" + +import re + +from collections.abc import Iterable +from sympy.core.function import Derivative + +_name_with_digits_p = re.compile(r'^([^\W\d_]+)(\d+)$', re.UNICODE) + + +def split_super_sub(text): + """Split a symbol name into a name, superscripts and subscripts + + The first part of the symbol name is considered to be its actual + 'name', followed by super- and subscripts. Each superscript is + preceded with a "^" character or by "__". Each subscript is preceded + by a "_" character. The three return values are the actual name, a + list with superscripts and a list with subscripts. + + Examples + ======== + + >>> from sympy.printing.conventions import split_super_sub + >>> split_super_sub('a_x^1') + ('a', ['1'], ['x']) + >>> split_super_sub('var_sub1__sup_sub2') + ('var', ['sup'], ['sub1', 'sub2']) + + """ + if not text: + return text, [], [] + + pos = 0 + name = None + supers = [] + subs = [] + while pos < len(text): + start = pos + 1 + if text[pos:pos + 2] == "__": + start += 1 + pos_hat = text.find("^", start) + if pos_hat < 0: + pos_hat = len(text) + pos_usc = text.find("_", start) + if pos_usc < 0: + pos_usc = len(text) + pos_next = min(pos_hat, pos_usc) + part = text[pos:pos_next] + pos = pos_next + if name is None: + name = part + elif part.startswith("^"): + supers.append(part[1:]) + elif part.startswith("__"): + supers.append(part[2:]) + elif part.startswith("_"): + subs.append(part[1:]) + else: + raise RuntimeError("This should never happen.") + + # Make a little exception when a name ends with digits, i.e. treat them + # as a subscript too. + m = _name_with_digits_p.match(name) + if m: + name, sub = m.groups() + subs.insert(0, sub) + + return name, supers, subs + + +def requires_partial(expr): + """Return whether a partial derivative symbol is required for printing + + This requires checking how many free variables there are, + filtering out the ones that are integers. Some expressions do not have + free variables. In that case, check its variable list explicitly to + get the context of the expression. + """ + + if isinstance(expr, Derivative): + return requires_partial(expr.expr) + + if not isinstance(expr.free_symbols, Iterable): + return len(set(expr.variables)) > 1 + + return sum(not s.is_integer for s in expr.free_symbols) > 1 diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/defaults.py b/.venv/lib/python3.13/site-packages/sympy/printing/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..77a88d353fed4bd70496456ddd03cc429a4ba5e7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/defaults.py @@ -0,0 +1,5 @@ +from sympy.core._print_helpers import Printable + +# alias for compatibility +Printable.__module__ = __name__ +DefaultPrinting = Printable diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/dot.py b/.venv/lib/python3.13/site-packages/sympy/printing/dot.py new file mode 100644 index 0000000000000000000000000000000000000000..c968fee389c16108b757b8fcad531ac6fa4ddb2f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/dot.py @@ -0,0 +1,294 @@ +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.symbol import Symbol +from sympy.core.numbers import Integer, Rational, Float +from sympy.printing.repr import srepr + +__all__ = ['dotprint'] + +default_styles = ( + (Basic, {'color': 'blue', 'shape': 'ellipse'}), + (Expr, {'color': 'black'}) +) + +slotClasses = (Symbol, Integer, Rational, Float) +def purestr(x, with_args=False): + """A string that follows ```obj = type(obj)(*obj.args)``` exactly. + + Parameters + ========== + + with_args : boolean, optional + If ``True``, there will be a second argument for the return + value, which is a tuple containing ``purestr`` applied to each + of the subnodes. + + If ``False``, there will not be a second argument for the + return. + + Default is ``False`` + + Examples + ======== + + >>> from sympy import Float, Symbol, MatrixSymbol + >>> from sympy import Integer # noqa: F401 + >>> from sympy.core.symbol import Str # noqa: F401 + >>> from sympy.printing.dot import purestr + + Applying ``purestr`` for basic symbolic object: + >>> code = purestr(Symbol('x')) + >>> code + "Symbol('x')" + >>> eval(code) == Symbol('x') + True + + For basic numeric object: + >>> purestr(Float(2)) + "Float('2.0', precision=53)" + + For matrix symbol: + >>> code = purestr(MatrixSymbol('x', 2, 2)) + >>> code + "MatrixSymbol(Str('x'), Integer(2), Integer(2))" + >>> eval(code) == MatrixSymbol('x', 2, 2) + True + + With ``with_args=True``: + >>> purestr(Float(2), with_args=True) + ("Float('2.0', precision=53)", ()) + >>> purestr(MatrixSymbol('x', 2, 2), with_args=True) + ("MatrixSymbol(Str('x'), Integer(2), Integer(2))", + ("Str('x')", 'Integer(2)', 'Integer(2)')) + """ + sargs = () + if not isinstance(x, Basic): + rv = str(x) + elif not x.args: + rv = srepr(x) + else: + args = x.args + sargs = tuple(map(purestr, args)) + rv = "%s(%s)"%(type(x).__name__, ', '.join(sargs)) + if with_args: + rv = rv, sargs + return rv + + +def styleof(expr, styles=default_styles): + """ Merge style dictionaries in order + + Examples + ======== + + >>> from sympy import Symbol, Basic, Expr, S + >>> from sympy.printing.dot import styleof + >>> styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}), + ... (Expr, {'color': 'black'})] + + >>> styleof(Basic(S(1)), styles) + {'color': 'blue', 'shape': 'ellipse'} + + >>> x = Symbol('x') + >>> styleof(x + 1, styles) # this is an Expr + {'color': 'black', 'shape': 'ellipse'} + """ + style = {} + for typ, sty in styles: + if isinstance(expr, typ): + style.update(sty) + return style + + +def attrprint(d, delimiter=', '): + """ Print a dictionary of attributes + + Examples + ======== + + >>> from sympy.printing.dot import attrprint + >>> print(attrprint({'color': 'blue', 'shape': 'ellipse'})) + "color"="blue", "shape"="ellipse" + """ + return delimiter.join('"%s"="%s"'%item for item in sorted(d.items())) + + +def dotnode(expr, styles=default_styles, labelfunc=str, pos=(), repeat=True): + """ String defining a node + + Examples + ======== + + >>> from sympy.printing.dot import dotnode + >>> from sympy.abc import x + >>> print(dotnode(x)) + "Symbol('x')_()" ["color"="black", "label"="x", "shape"="ellipse"]; + """ + style = styleof(expr, styles) + + if isinstance(expr, Basic) and not expr.is_Atom: + label = str(expr.__class__.__name__) + else: + label = labelfunc(expr) + style['label'] = label + expr_str = purestr(expr) + if repeat: + expr_str += '_%s' % str(pos) + return '"%s" [%s];' % (expr_str, attrprint(style)) + + +def dotedges(expr, atom=lambda x: not isinstance(x, Basic), pos=(), repeat=True): + """ List of strings for all expr->expr.arg pairs + + See the docstring of dotprint for explanations of the options. + + Examples + ======== + + >>> from sympy.printing.dot import dotedges + >>> from sympy.abc import x + >>> for e in dotedges(x+2): + ... print(e) + "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)"; + "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)"; + """ + if atom(expr): + return [] + else: + expr_str, arg_strs = purestr(expr, with_args=True) + if repeat: + expr_str += '_%s' % str(pos) + arg_strs = ['%s_%s' % (a, str(pos + (i,))) + for i, a in enumerate(arg_strs)] + return ['"%s" -> "%s";' % (expr_str, a) for a in arg_strs] + +template = \ +"""digraph{ + +# Graph style +%(graphstyle)s + +######### +# Nodes # +######### + +%(nodes)s + +######### +# Edges # +######### + +%(edges)s +}""" + +_graphstyle = {'rankdir': 'TD', 'ordering': 'out'} + +def dotprint(expr, + styles=default_styles, atom=lambda x: not isinstance(x, Basic), + maxdepth=None, repeat=True, labelfunc=str, **kwargs): + """DOT description of a SymPy expression tree + + Parameters + ========== + + styles : list of lists composed of (Class, mapping), optional + Styles for different classes. + + The default is + + .. code-block:: python + + ( + (Basic, {'color': 'blue', 'shape': 'ellipse'}), + (Expr, {'color': 'black'}) + ) + + atom : function, optional + Function used to determine if an arg is an atom. + + A good choice is ``lambda x: not x.args``. + + The default is ``lambda x: not isinstance(x, Basic)``. + + maxdepth : integer, optional + The maximum depth. + + The default is ``None``, meaning no limit. + + repeat : boolean, optional + Whether to use different nodes for common subexpressions. + + The default is ``True``. + + For example, for ``x + x*y`` with ``repeat=True``, it will have + two nodes for ``x``; with ``repeat=False``, it will have one + node. + + .. warning:: + Even if a node appears twice in the same object like ``x`` in + ``Pow(x, x)``, it will still only appear once. + Hence, with ``repeat=False``, the number of arrows out of an + object might not equal the number of args it has. + + labelfunc : function, optional + A function to create a label for a given leaf node. + + The default is ``str``. + + Another good option is ``srepr``. + + For example with ``str``, the leaf nodes of ``x + 1`` are labeled, + ``x`` and ``1``. With ``srepr``, they are labeled ``Symbol('x')`` + and ``Integer(1)``. + + **kwargs : optional + Additional keyword arguments are included as styles for the graph. + + Examples + ======== + + >>> from sympy import dotprint + >>> from sympy.abc import x + >>> print(dotprint(x+2)) # doctest: +NORMALIZE_WHITESPACE + digraph{ + + # Graph style + "ordering"="out" + "rankdir"="TD" + + ######### + # Nodes # + ######### + + "Add(Integer(2), Symbol('x'))_()" ["color"="black", "label"="Add", "shape"="ellipse"]; + "Integer(2)_(0,)" ["color"="black", "label"="2", "shape"="ellipse"]; + "Symbol('x')_(1,)" ["color"="black", "label"="x", "shape"="ellipse"]; + + ######### + # Edges # + ######### + + "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)"; + "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)"; + } + + """ + # repeat works by adding a signature tuple to the end of each node for its + # position in the graph. For example, for expr = Add(x, Pow(x, 2)), the x in the + # Pow will have the tuple (1, 0), meaning it is expr.args[1].args[0]. + graphstyle = _graphstyle.copy() + graphstyle.update(kwargs) + + nodes = [] + edges = [] + def traverse(e, depth, pos=()): + nodes.append(dotnode(e, styles, labelfunc=labelfunc, pos=pos, repeat=repeat)) + if maxdepth and depth >= maxdepth: + return + edges.extend(dotedges(e, atom=atom, pos=pos, repeat=repeat)) + [traverse(arg, depth+1, pos + (i,)) for i, arg in enumerate(e.args) if not atom(arg)] + traverse(expr, 0) + + return template%{'graphstyle': attrprint(graphstyle, delimiter='\n'), + 'nodes': '\n'.join(nodes), + 'edges': '\n'.join(edges)} diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/fortran.py b/.venv/lib/python3.13/site-packages/sympy/printing/fortran.py new file mode 100644 index 0000000000000000000000000000000000000000..7cea812d72ddcd3ccb56c7258f74e6fe3b8d5211 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/fortran.py @@ -0,0 +1,779 @@ +""" +Fortran code printer + +The FCodePrinter converts single SymPy expressions into single Fortran +expressions, using the functions defined in the Fortran 77 standard where +possible. Some useful pointers to Fortran can be found on wikipedia: + +https://en.wikipedia.org/wiki/Fortran + +Most of the code below is based on the "Professional Programmer\'s Guide to +Fortran77" by Clive G. Page: + +https://www.star.le.ac.uk/~cgp/prof77.html + +Fortran is a case-insensitive language. This might cause trouble because +SymPy is case sensitive. So, fcode adds underscores to variable names when +it is necessary to make them different for Fortran. +""" + +from __future__ import annotations +from typing import Any + +from collections import defaultdict +from itertools import chain +import string + +from sympy.codegen.ast import ( + Assignment, Declaration, Pointer, value_const, + float32, float64, float80, complex64, complex128, int8, int16, int32, + int64, intc, real, integer, bool_, complex_, none, stderr, stdout +) +from sympy.codegen.fnodes import ( + allocatable, isign, dsign, cmplx, merge, literal_dp, elemental, pure, + intent_in, intent_out, intent_inout +) +from sympy.core import S, Add, N, Float, Symbol +from sympy.core.function import Function +from sympy.core.numbers import equal_valued +from sympy.core.relational import Eq +from sympy.sets import Range +from sympy.printing.codeprinter import CodePrinter +from sympy.printing.precedence import precedence, PRECEDENCE +from sympy.printing.printer import printer_context + +# These are defined in the other file so we can avoid importing sympy.codegen +# from the top-level 'import sympy'. Export them here as well. +from sympy.printing.codeprinter import fcode, print_fcode # noqa:F401 + +known_functions = { + "sin": "sin", + "cos": "cos", + "tan": "tan", + "asin": "asin", + "acos": "acos", + "atan": "atan", + "atan2": "atan2", + "sinh": "sinh", + "cosh": "cosh", + "tanh": "tanh", + "log": "log", + "exp": "exp", + "erf": "erf", + "Abs": "abs", + "conjugate": "conjg", + "Max": "max", + "Min": "min", +} + + +class FCodePrinter(CodePrinter): + """A printer to convert SymPy expressions to strings of Fortran code""" + printmethod = "_fcode" + language = "Fortran" + + type_aliases = { + integer: int32, + real: float64, + complex_: complex128, + } + + type_mappings = { + intc: 'integer(c_int)', + float32: 'real*4', # real(kind(0.e0)) + float64: 'real*8', # real(kind(0.d0)) + float80: 'real*10', # real(kind(????)) + complex64: 'complex*8', + complex128: 'complex*16', + int8: 'integer*1', + int16: 'integer*2', + int32: 'integer*4', + int64: 'integer*8', + bool_: 'logical' + } + + type_modules = { + intc: {'iso_c_binding': 'c_int'} + } + + _default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{ + 'precision': 17, + 'user_functions': {}, + 'source_format': 'fixed', + 'contract': True, + 'standard': 77, + 'name_mangling': True, + }) + + _operators = { + 'and': '.and.', + 'or': '.or.', + 'xor': '.neqv.', + 'equivalent': '.eqv.', + 'not': '.not. ', + } + + _relationals = { + '!=': '/=', + } + + def __init__(self, settings=None): + if not settings: + settings = {} + self.mangled_symbols = {} # Dict showing mapping of all words + self.used_name = [] + self.type_aliases = dict(chain(self.type_aliases.items(), + settings.pop('type_aliases', {}).items())) + self.type_mappings = dict(chain(self.type_mappings.items(), + settings.pop('type_mappings', {}).items())) + super().__init__(settings) + self.known_functions = dict(known_functions) + userfuncs = settings.get('user_functions', {}) + self.known_functions.update(userfuncs) + # leading columns depend on fixed or free format + standards = {66, 77, 90, 95, 2003, 2008} + if self._settings['standard'] not in standards: + raise ValueError("Unknown Fortran standard: %s" % self._settings[ + 'standard']) + self.module_uses = defaultdict(set) # e.g.: use iso_c_binding, only: c_int + + @property + def _lead(self): + if self._settings['source_format'] == 'fixed': + return {'code': " ", 'cont': " @ ", 'comment': "C "} + elif self._settings['source_format'] == 'free': + return {'code': "", 'cont': " ", 'comment': "! "} + else: + raise ValueError("Unknown source format: %s" % self._settings['source_format']) + + def _print_Symbol(self, expr): + if self._settings['name_mangling'] == True: + if expr not in self.mangled_symbols: + name = expr.name + while name.lower() in self.used_name: + name += '_' + self.used_name.append(name.lower()) + if name == expr.name: + self.mangled_symbols[expr] = expr + else: + self.mangled_symbols[expr] = Symbol(name) + + expr = expr.xreplace(self.mangled_symbols) + + name = super()._print_Symbol(expr) + return name + + def _rate_index_position(self, p): + return -p*5 + + def _get_statement(self, codestring): + return codestring + + def _get_comment(self, text): + return "! {}".format(text) + + def _declare_number_const(self, name, value): + return "parameter ({} = {})".format(name, self._print(value)) + + def _print_NumberSymbol(self, expr): + # A Number symbol that is not implemented here or with _printmethod + # is registered and evaluated + self._number_symbols.add((expr, Float(expr.evalf(self._settings['precision'])))) + return str(expr) + + def _format_code(self, lines): + return self._wrap_fortran(self.indent_code(lines)) + + def _traverse_matrix_indices(self, mat): + rows, cols = mat.shape + return ((i, j) for j in range(cols) for i in range(rows)) + + def _get_loop_opening_ending(self, indices): + open_lines = [] + close_lines = [] + for i in indices: + # fortran arrays start at 1 and end at dimension + var, start, stop = map(self._print, + [i.label, i.lower + 1, i.upper + 1]) + open_lines.append("do %s = %s, %s" % (var, start, stop)) + close_lines.append("end do") + return open_lines, close_lines + + def _print_sign(self, expr): + from sympy.functions.elementary.complexes import Abs + arg, = expr.args + if arg.is_integer: + new_expr = merge(0, isign(1, arg), Eq(arg, 0)) + elif (arg.is_complex or arg.is_infinite): + new_expr = merge(cmplx(literal_dp(0), literal_dp(0)), arg/Abs(arg), Eq(Abs(arg), literal_dp(0))) + else: + new_expr = merge(literal_dp(0), dsign(literal_dp(1), arg), Eq(arg, literal_dp(0))) + return self._print(new_expr) + + + def _print_Piecewise(self, expr): + if expr.args[-1].cond != True: + # We need the last conditional to be a True, otherwise the resulting + # function may not return a result. + raise ValueError("All Piecewise expressions must contain an " + "(expr, True) statement to be used as a default " + "condition. Without one, the generated " + "expression may not evaluate to anything under " + "some condition.") + lines = [] + if expr.has(Assignment): + for i, (e, c) in enumerate(expr.args): + if i == 0: + lines.append("if (%s) then" % self._print(c)) + elif i == len(expr.args) - 1 and c == True: + lines.append("else") + else: + lines.append("else if (%s) then" % self._print(c)) + lines.append(self._print(e)) + lines.append("end if") + return "\n".join(lines) + elif self._settings["standard"] >= 95: + # Only supported in F95 and newer: + # The piecewise was used in an expression, need to do inline + # operators. This has the downside that inline operators will + # not work for statements that span multiple lines (Matrix or + # Indexed expressions). + pattern = "merge({T}, {F}, {COND})" + code = self._print(expr.args[-1].expr) + terms = list(expr.args[:-1]) + while terms: + e, c = terms.pop() + expr = self._print(e) + cond = self._print(c) + code = pattern.format(T=expr, F=code, COND=cond) + return code + else: + # `merge` is not supported prior to F95 + raise NotImplementedError("Using Piecewise as an expression using " + "inline operators is not supported in " + "standards earlier than Fortran95.") + + def _print_MatrixElement(self, expr): + return "{}({}, {})".format(self.parenthesize(expr.parent, + PRECEDENCE["Atom"], strict=True), expr.i + 1, expr.j + 1) + + def _print_Add(self, expr): + # purpose: print complex numbers nicely in Fortran. + # collect the purely real and purely imaginary parts: + pure_real = [] + pure_imaginary = [] + mixed = [] + for arg in expr.args: + if arg.is_number and arg.is_real: + pure_real.append(arg) + elif arg.is_number and arg.is_imaginary: + pure_imaginary.append(arg) + else: + mixed.append(arg) + if pure_imaginary: + if mixed: + PREC = precedence(expr) + term = Add(*mixed) + t = self._print(term) + if t.startswith('-'): + sign = "-" + t = t[1:] + else: + sign = "+" + if precedence(term) < PREC: + t = "(%s)" % t + + return "cmplx(%s,%s) %s %s" % ( + self._print(Add(*pure_real)), + self._print(-S.ImaginaryUnit*Add(*pure_imaginary)), + sign, t, + ) + else: + return "cmplx(%s,%s)" % ( + self._print(Add(*pure_real)), + self._print(-S.ImaginaryUnit*Add(*pure_imaginary)), + ) + else: + return CodePrinter._print_Add(self, expr) + + def _print_Function(self, expr): + # All constant function args are evaluated as floats + prec = self._settings['precision'] + args = [N(a, prec) for a in expr.args] + eval_expr = expr.func(*args) + if not isinstance(eval_expr, Function): + return self._print(eval_expr) + else: + return CodePrinter._print_Function(self, expr.func(*args)) + + def _print_Mod(self, expr): + # NOTE : Fortran has the functions mod() and modulo(). modulo() behaves + # the same wrt to the sign of the arguments as Python and SymPy's + # modulus computations (% and Mod()) but is not available in Fortran 66 + # or Fortran 77, thus we raise an error. + if self._settings['standard'] in [66, 77]: + msg = ("Python % operator and SymPy's Mod() function are not " + "supported by Fortran 66 or 77 standards.") + raise NotImplementedError(msg) + else: + x, y = expr.args + return " modulo({}, {})".format(self._print(x), self._print(y)) + + def _print_ImaginaryUnit(self, expr): + # purpose: print complex numbers nicely in Fortran. + return "cmplx(0,1)" + + def _print_int(self, expr): + return str(expr) + + def _print_Mul(self, expr): + # purpose: print complex numbers nicely in Fortran. + if expr.is_number and expr.is_imaginary: + return "cmplx(0,%s)" % ( + self._print(-S.ImaginaryUnit*expr) + ) + else: + return CodePrinter._print_Mul(self, expr) + + def _print_Pow(self, expr): + PREC = precedence(expr) + if equal_valued(expr.exp, -1): + return '%s/%s' % ( + self._print(literal_dp(1)), + self.parenthesize(expr.base, PREC) + ) + elif equal_valued(expr.exp, 0.5): + if expr.base.is_integer: + # Fortran intrinsic sqrt() does not accept integer argument + if expr.base.is_Number: + return 'sqrt(%s.0d0)' % self._print(expr.base) + else: + return 'sqrt(dble(%s))' % self._print(expr.base) + else: + return 'sqrt(%s)' % self._print(expr.base) + else: + return CodePrinter._print_Pow(self, expr) + + def _print_Rational(self, expr): + p, q = int(expr.p), int(expr.q) + return "%d.0d0/%d.0d0" % (p, q) + + def _print_Float(self, expr): + printed = CodePrinter._print_Float(self, expr) + e = printed.find('e') + if e > -1: + return "%sd%s" % (printed[:e], printed[e + 1:]) + return "%sd0" % printed + + def _print_Relational(self, expr): + lhs_code = self._print(expr.lhs) + rhs_code = self._print(expr.rhs) + op = expr.rel_op + op = op if op not in self._relationals else self._relationals[op] + return "{} {} {}".format(lhs_code, op, rhs_code) + + def _print_Indexed(self, expr): + inds = [ self._print(i) for i in expr.indices ] + return "%s(%s)" % (self._print(expr.base.label), ", ".join(inds)) + + def _print_AugmentedAssignment(self, expr): + lhs_code = self._print(expr.lhs) + rhs_code = self._print(expr.rhs) + return self._get_statement("{0} = {0} {1} {2}".format( + self._print(lhs_code), self._print(expr.binop), self._print(rhs_code))) + + def _print_sum_(self, sm): + params = self._print(sm.array) + if sm.dim != None: # Must use '!= None', cannot use 'is not None' + params += ', ' + self._print(sm.dim) + if sm.mask != None: # Must use '!= None', cannot use 'is not None' + params += ', mask=' + self._print(sm.mask) + return '%s(%s)' % (sm.__class__.__name__.rstrip('_'), params) + + def _print_product_(self, prod): + return self._print_sum_(prod) + + def _print_Do(self, do): + excl = ['concurrent'] + if do.step == 1: + excl.append('step') + step = '' + else: + step = ', {step}' + + return ( + 'do {concurrent}{counter} = {first}, {last}'+step+'\n' + '{body}\n' + 'end do\n' + ).format( + concurrent='concurrent ' if do.concurrent else '', + **do.kwargs(apply=lambda arg: self._print(arg), exclude=excl) + ) + + def _print_ImpliedDoLoop(self, idl): + step = '' if idl.step == 1 else ', {step}' + return ('({expr}, {counter} = {first}, {last}'+step+')').format( + **idl.kwargs(apply=lambda arg: self._print(arg)) + ) + + def _print_For(self, expr): + target = self._print(expr.target) + if isinstance(expr.iterable, Range): + start, stop, step = expr.iterable.args + else: + raise NotImplementedError("Only iterable currently supported is Range") + body = self._print(expr.body) + return ('do {target} = {start}, {stop}, {step}\n' + '{body}\n' + 'end do').format(target=target, start=start, stop=stop - 1, + step=step, body=body) + + def _print_Type(self, type_): + type_ = self.type_aliases.get(type_, type_) + type_str = self.type_mappings.get(type_, type_.name) + module_uses = self.type_modules.get(type_) + if module_uses: + for k, v in module_uses: + self.module_uses[k].add(v) + return type_str + + def _print_Element(self, elem): + return '{symbol}({idxs})'.format( + symbol=self._print(elem.symbol), + idxs=', '.join((self._print(arg) for arg in elem.indices)) + ) + + def _print_Extent(self, ext): + return str(ext) + + def _print_Declaration(self, expr): + var = expr.variable + val = var.value + dim = var.attr_params('dimension') + intents = [intent in var.attrs for intent in (intent_in, intent_out, intent_inout)] + if intents.count(True) == 0: + intent = '' + elif intents.count(True) == 1: + intent = ', intent(%s)' % ['in', 'out', 'inout'][intents.index(True)] + else: + raise ValueError("Multiple intents specified for %s" % self) + + if isinstance(var, Pointer): + raise NotImplementedError("Pointers are not available by default in Fortran.") + if self._settings["standard"] >= 90: + result = '{t}{vc}{dim}{intent}{alloc} :: {s}'.format( + t=self._print(var.type), + vc=', parameter' if value_const in var.attrs else '', + dim=', dimension(%s)' % ', '.join((self._print(arg) for arg in dim)) if dim else '', + intent=intent, + alloc=', allocatable' if allocatable in var.attrs else '', + s=self._print(var.symbol) + ) + if val != None: # Must be "!= None", cannot be "is not None" + result += ' = %s' % self._print(val) + else: + if value_const in var.attrs or val: + raise NotImplementedError("F77 init./parameter statem. req. multiple lines.") + result = ' '.join((self._print(arg) for arg in [var.type, var.symbol])) + + return result + + + def _print_Infinity(self, expr): + return '(huge(%s) + 1)' % self._print(literal_dp(0)) + + def _print_While(self, expr): + return 'do while ({condition})\n{body}\nend do'.format(**expr.kwargs( + apply=lambda arg: self._print(arg))) + + def _print_BooleanTrue(self, expr): + return '.true.' + + def _print_BooleanFalse(self, expr): + return '.false.' + + def _pad_leading_columns(self, lines): + result = [] + for line in lines: + if line.startswith('!'): + result.append(self._lead['comment'] + line[1:].lstrip()) + else: + result.append(self._lead['code'] + line) + return result + + def _wrap_fortran(self, lines): + """Wrap long Fortran lines + + Argument: + lines -- a list of lines (without \\n character) + + A comment line is split at white space. Code lines are split with a more + complex rule to give nice results. + """ + # routine to find split point in a code line + my_alnum = set("_+-." + string.digits + string.ascii_letters) + my_white = set(" \t()") + + def split_pos_code(line, endpos): + if len(line) <= endpos: + return len(line) + pos = endpos + split = lambda pos: \ + (line[pos] in my_alnum and line[pos - 1] not in my_alnum) or \ + (line[pos] not in my_alnum and line[pos - 1] in my_alnum) or \ + (line[pos] in my_white and line[pos - 1] not in my_white) or \ + (line[pos] not in my_white and line[pos - 1] in my_white) + while not split(pos): + pos -= 1 + if pos == 0: + return endpos + return pos + # split line by line and add the split lines to result + result = [] + if self._settings['source_format'] == 'free': + trailing = ' &' + else: + trailing = '' + for line in lines: + if line.startswith(self._lead['comment']): + # comment line + if len(line) > 72: + pos = line.rfind(" ", 6, 72) + if pos == -1: + pos = 72 + hunk = line[:pos] + line = line[pos:].lstrip() + result.append(hunk) + while line: + pos = line.rfind(" ", 0, 66) + if pos == -1 or len(line) < 66: + pos = 66 + hunk = line[:pos] + line = line[pos:].lstrip() + result.append("%s%s" % (self._lead['comment'], hunk)) + else: + result.append(line) + elif line.startswith(self._lead['code']): + # code line + pos = split_pos_code(line, 72) + hunk = line[:pos].rstrip() + line = line[pos:].lstrip() + if line: + hunk += trailing + result.append(hunk) + while line: + pos = split_pos_code(line, 65) + hunk = line[:pos].rstrip() + line = line[pos:].lstrip() + if line: + hunk += trailing + result.append("%s%s" % (self._lead['cont'], hunk)) + else: + result.append(line) + return result + + def indent_code(self, code): + """Accepts a string of code or a list of code lines""" + if isinstance(code, str): + code_lines = self.indent_code(code.splitlines(True)) + return ''.join(code_lines) + + free = self._settings['source_format'] == 'free' + code = [ line.lstrip(' \t') for line in code ] + + inc_keyword = ('do ', 'if(', 'if ', 'do\n', 'else', 'program', 'interface') + dec_keyword = ('end do', 'enddo', 'end if', 'endif', 'else', 'end program', 'end interface') + + increase = [ int(any(map(line.startswith, inc_keyword))) + for line in code ] + decrease = [ int(any(map(line.startswith, dec_keyword))) + for line in code ] + continuation = [ int(any(map(line.endswith, ['&', '&\n']))) + for line in code ] + + level = 0 + cont_padding = 0 + tabwidth = 3 + new_code = [] + for i, line in enumerate(code): + if line in ('', '\n'): + new_code.append(line) + continue + level -= decrease[i] + + if free: + padding = " "*(level*tabwidth + cont_padding) + else: + padding = " "*level*tabwidth + + line = "%s%s" % (padding, line) + if not free: + line = self._pad_leading_columns([line])[0] + + new_code.append(line) + + if continuation[i]: + cont_padding = 2*tabwidth + else: + cont_padding = 0 + level += increase[i] + + if not free: + return self._wrap_fortran(new_code) + return new_code + + def _print_GoTo(self, goto): + if goto.expr: # computed goto + return "go to ({labels}), {expr}".format( + labels=', '.join((self._print(arg) for arg in goto.labels)), + expr=self._print(goto.expr) + ) + else: + lbl, = goto.labels + return "go to %s" % self._print(lbl) + + def _print_Program(self, prog): + return ( + "program {name}\n" + "{body}\n" + "end program\n" + ).format(**prog.kwargs(apply=lambda arg: self._print(arg))) + + def _print_Module(self, mod): + return ( + "module {name}\n" + "{declarations}\n" + "\ncontains\n\n" + "{definitions}\n" + "end module\n" + ).format(**mod.kwargs(apply=lambda arg: self._print(arg))) + + def _print_Stream(self, strm): + if strm.name == 'stdout' and self._settings["standard"] >= 2003: + self.module_uses['iso_c_binding'].add('stdint=>input_unit') + return 'input_unit' + elif strm.name == 'stderr' and self._settings["standard"] >= 2003: + self.module_uses['iso_c_binding'].add('stdint=>error_unit') + return 'error_unit' + else: + if strm.name == 'stdout': + return '*' + else: + return strm.name + + def _print_Print(self, ps): + if ps.format_string == none: # Must be '!= None', cannot be 'is not None' + template = "print {fmt}, {iolist}" + fmt = '*' + else: + template = 'write(%(out)s, fmt="{fmt}", advance="no"), {iolist}' % { + 'out': {stderr: '0', stdout: '6'}.get(ps.file, '*') + } + fmt = self._print(ps.format_string) + return template.format(fmt=fmt, iolist=', '.join( + (self._print(arg) for arg in ps.print_args))) + + def _print_Return(self, rs): + arg, = rs.args + return "{result_name} = {arg}".format( + result_name=self._context.get('result_name', 'sympy_result'), + arg=self._print(arg) + ) + + def _print_FortranReturn(self, frs): + arg, = frs.args + if arg: + return 'return %s' % self._print(arg) + else: + return 'return' + + def _head(self, entity, fp, **kwargs): + bind_C_params = fp.attr_params('bind_C') + if bind_C_params is None: + bind = '' + else: + bind = ' bind(C, name="%s")' % bind_C_params[0] if bind_C_params else ' bind(C)' + result_name = self._settings.get('result_name', None) + return ( + "{entity}{name}({arg_names}){result}{bind}\n" + "{arg_declarations}" + ).format( + entity=entity, + name=self._print(fp.name), + arg_names=', '.join([self._print(arg.symbol) for arg in fp.parameters]), + result=(' result(%s)' % result_name) if result_name else '', + bind=bind, + arg_declarations='\n'.join((self._print(Declaration(arg)) for arg in fp.parameters)) + ) + + def _print_FunctionPrototype(self, fp): + entity = "{} function ".format(self._print(fp.return_type)) + return ( + "interface\n" + "{function_head}\n" + "end function\n" + "end interface" + ).format(function_head=self._head(entity, fp)) + + def _print_FunctionDefinition(self, fd): + if elemental in fd.attrs: + prefix = 'elemental ' + elif pure in fd.attrs: + prefix = 'pure ' + else: + prefix = '' + + entity = "{} function ".format(self._print(fd.return_type)) + with printer_context(self, result_name=fd.name): + return ( + "{prefix}{function_head}\n" + "{body}\n" + "end function\n" + ).format( + prefix=prefix, + function_head=self._head(entity, fd), + body=self._print(fd.body) + ) + + def _print_Subroutine(self, sub): + return ( + '{subroutine_head}\n' + '{body}\n' + 'end subroutine\n' + ).format( + subroutine_head=self._head('subroutine ', sub), + body=self._print(sub.body) + ) + + def _print_SubroutineCall(self, scall): + return 'call {name}({args})'.format( + name=self._print(scall.name), + args=', '.join((self._print(arg) for arg in scall.subroutine_args)) + ) + + def _print_use_rename(self, rnm): + return "%s => %s" % tuple((self._print(arg) for arg in rnm.args)) + + def _print_use(self, use): + result = 'use %s' % self._print(use.namespace) + if use.rename != None: # Must be '!= None', cannot be 'is not None' + result += ', ' + ', '.join([self._print(rnm) for rnm in use.rename]) + if use.only != None: # Must be '!= None', cannot be 'is not None' + result += ', only: ' + ', '.join([self._print(nly) for nly in use.only]) + return result + + def _print_BreakToken(self, _): + return 'exit' + + def _print_ContinueToken(self, _): + return 'cycle' + + def _print_ArrayConstructor(self, ac): + fmtstr = "[%s]" if self._settings["standard"] >= 2003 else '(/%s/)' + return fmtstr % ', '.join((self._print(arg) for arg in ac.elements)) + + def _print_ArrayElement(self, elem): + return '{symbol}({idxs})'.format( + symbol=self._print(elem.name), + idxs=', '.join((self._print(arg) for arg in elem.indices)) + ) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/julia.py b/.venv/lib/python3.13/site-packages/sympy/printing/julia.py new file mode 100644 index 0000000000000000000000000000000000000000..3ab815add4d87fe953f646409e3a7bb383b1bbc6 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/julia.py @@ -0,0 +1,652 @@ +""" +Julia code printer + +The `JuliaCodePrinter` converts SymPy expressions into Julia expressions. + +A complete code generator, which uses `julia_code` extensively, can be found +in `sympy.utilities.codegen`. The `codegen` module can be used to generate +complete source code files. + +""" + +from __future__ import annotations +from typing import Any + +from sympy.core import Mul, Pow, S, Rational +from sympy.core.mul import _keep_coeff +from sympy.core.numbers import equal_valued +from sympy.printing.codeprinter import CodePrinter +from sympy.printing.precedence import precedence, PRECEDENCE +from re import search + +# List of known functions. First, those that have the same name in +# SymPy and Julia. This is almost certainly incomplete! +known_fcns_src1 = ["sin", "cos", "tan", "cot", "sec", "csc", + "asin", "acos", "atan", "acot", "asec", "acsc", + "sinh", "cosh", "tanh", "coth", "sech", "csch", + "asinh", "acosh", "atanh", "acoth", "asech", "acsch", + "atan2", "sign", "floor", "log", "exp", + "cbrt", "sqrt", "erf", "erfc", "erfi", + "factorial", "gamma", "digamma", "trigamma", + "polygamma", "beta", + "airyai", "airyaiprime", "airybi", "airybiprime", + "besselj", "bessely", "besseli", "besselk", + "erfinv", "erfcinv"] +# These functions have different names ("SymPy": "Julia"), more +# generally a mapping to (argument_conditions, julia_function). +known_fcns_src2 = { + "Abs": "abs", + "ceiling": "ceil", + "conjugate": "conj", + "hankel1": "hankelh1", + "hankel2": "hankelh2", + "im": "imag", + "re": "real" +} + + +class JuliaCodePrinter(CodePrinter): + """ + A printer to convert expressions to strings of Julia code. + """ + printmethod = "_julia" + language = "Julia" + + _operators = { + 'and': '&&', + 'or': '||', + 'not': '!', + } + + _default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{ + 'precision': 17, + 'user_functions': {}, + 'contract': True, + 'inline': True, + }) + # Note: contract is for expressing tensors as loops (if True), or just + # assignment (if False). FIXME: this should be looked a more carefully + # for Julia. + + def __init__(self, settings={}): + super().__init__(settings) + self.known_functions = dict(zip(known_fcns_src1, known_fcns_src1)) + self.known_functions.update(dict(known_fcns_src2)) + userfuncs = settings.get('user_functions', {}) + self.known_functions.update(userfuncs) + + + def _rate_index_position(self, p): + return p*5 + + + def _get_statement(self, codestring): + return "%s" % codestring + + + def _get_comment(self, text): + return "# {}".format(text) + + + def _declare_number_const(self, name, value): + return "const {} = {}".format(name, value) + + + def _format_code(self, lines): + return self.indent_code(lines) + + + def _traverse_matrix_indices(self, mat): + # Julia uses Fortran order (column-major) + rows, cols = mat.shape + return ((i, j) for j in range(cols) for i in range(rows)) + + + def _get_loop_opening_ending(self, indices): + open_lines = [] + close_lines = [] + for i in indices: + # Julia arrays start at 1 and end at dimension + var, start, stop = map(self._print, + [i.label, i.lower + 1, i.upper + 1]) + open_lines.append("for %s = %s:%s" % (var, start, stop)) + close_lines.append("end") + return open_lines, close_lines + + + def _print_Mul(self, expr): + # print complex numbers nicely in Julia + if (expr.is_number and expr.is_imaginary and + expr.as_coeff_Mul()[0].is_integer): + return "%sim" % self._print(-S.ImaginaryUnit*expr) + + # cribbed from str.py + prec = precedence(expr) + + c, e = expr.as_coeff_Mul() + if c < 0: + expr = _keep_coeff(-c, e) + sign = "-" + else: + sign = "" + + a = [] # items in the numerator + b = [] # items that are in the denominator (if any) + + pow_paren = [] # Will collect all pow with more than one base element and exp = -1 + + if self.order not in ('old', 'none'): + args = expr.as_ordered_factors() + else: + # use make_args in case expr was something like -x -> x + args = Mul.make_args(expr) + + # Gather args for numerator/denominator + for item in args: + if (item.is_commutative and item.is_Pow and item.exp.is_Rational + and item.exp.is_negative): + if item.exp != -1: + b.append(Pow(item.base, -item.exp, evaluate=False)) + else: + if len(item.args[0].args) != 1 and isinstance(item.base, Mul): # To avoid situations like #14160 + pow_paren.append(item) + b.append(Pow(item.base, -item.exp)) + elif item.is_Rational and item is not S.Infinity and item.p == 1: + # Save the Rational type in julia Unless the numerator is 1. + # For example: + # julia_code(Rational(3, 7)*x) --> (3 // 7) * x + # julia_code(x/3) --> x / 3 but not x * (1 // 3) + b.append(Rational(item.q)) + else: + a.append(item) + + a = a or [S.One] + + a_str = [self.parenthesize(x, prec) for x in a] + b_str = [self.parenthesize(x, prec) for x in b] + + # To parenthesize Pow with exp = -1 and having more than one Symbol + for item in pow_paren: + if item.base in b: + b_str[b.index(item.base)] = "(%s)" % b_str[b.index(item.base)] + + # from here it differs from str.py to deal with "*" and ".*" + def multjoin(a, a_str): + # here we probably are assuming the constants will come first + r = a_str[0] + for i in range(1, len(a)): + mulsym = '*' if a[i-1].is_number else '.*' + r = "%s %s %s" % (r, mulsym, a_str[i]) + return r + + if not b: + return sign + multjoin(a, a_str) + elif len(b) == 1: + divsym = '/' if b[0].is_number else './' + return "%s %s %s" % (sign+multjoin(a, a_str), divsym, b_str[0]) + else: + divsym = '/' if all(bi.is_number for bi in b) else './' + return "%s %s (%s)" % (sign + multjoin(a, a_str), divsym, multjoin(b, b_str)) + + def _print_Relational(self, expr): + lhs_code = self._print(expr.lhs) + rhs_code = self._print(expr.rhs) + op = expr.rel_op + return "{} {} {}".format(lhs_code, op, rhs_code) + + def _print_Pow(self, expr): + powsymbol = '^' if all(x.is_number for x in expr.args) else '.^' + + PREC = precedence(expr) + + if equal_valued(expr.exp, 0.5): + return "sqrt(%s)" % self._print(expr.base) + + if expr.is_commutative: + if equal_valued(expr.exp, -0.5): + sym = '/' if expr.base.is_number else './' + return "1 %s sqrt(%s)" % (sym, self._print(expr.base)) + if equal_valued(expr.exp, -1): + sym = '/' if expr.base.is_number else './' + return "1 %s %s" % (sym, self.parenthesize(expr.base, PREC)) + + return '%s %s %s' % (self.parenthesize(expr.base, PREC), powsymbol, + self.parenthesize(expr.exp, PREC)) + + + def _print_MatPow(self, expr): + PREC = precedence(expr) + return '%s ^ %s' % (self.parenthesize(expr.base, PREC), + self.parenthesize(expr.exp, PREC)) + + + def _print_Pi(self, expr): + if self._settings["inline"]: + return "pi" + else: + return super()._print_NumberSymbol(expr) + + + def _print_ImaginaryUnit(self, expr): + return "im" + + + def _print_Exp1(self, expr): + if self._settings["inline"]: + return "e" + else: + return super()._print_NumberSymbol(expr) + + + def _print_EulerGamma(self, expr): + if self._settings["inline"]: + return "eulergamma" + else: + return super()._print_NumberSymbol(expr) + + + def _print_Catalan(self, expr): + if self._settings["inline"]: + return "catalan" + else: + return super()._print_NumberSymbol(expr) + + + def _print_GoldenRatio(self, expr): + if self._settings["inline"]: + return "golden" + else: + return super()._print_NumberSymbol(expr) + + + def _print_Assignment(self, expr): + from sympy.codegen.ast import Assignment + from sympy.functions.elementary.piecewise import Piecewise + from sympy.tensor.indexed import IndexedBase + # Copied from codeprinter, but remove special MatrixSymbol treatment + lhs = expr.lhs + rhs = expr.rhs + # We special case assignments that take multiple lines + if not self._settings["inline"] and isinstance(expr.rhs, Piecewise): + # Here we modify Piecewise so each expression is now + # an Assignment, and then continue on the print. + expressions = [] + conditions = [] + for (e, c) in rhs.args: + expressions.append(Assignment(lhs, e)) + conditions.append(c) + temp = Piecewise(*zip(expressions, conditions)) + return self._print(temp) + if self._settings["contract"] and (lhs.has(IndexedBase) or + rhs.has(IndexedBase)): + # Here we check if there is looping to be done, and if so + # print the required loops. + return self._doprint_loops(rhs, lhs) + else: + lhs_code = self._print(lhs) + rhs_code = self._print(rhs) + return self._get_statement("%s = %s" % (lhs_code, rhs_code)) + + + def _print_Infinity(self, expr): + return 'Inf' + + + def _print_NegativeInfinity(self, expr): + return '-Inf' + + + def _print_NaN(self, expr): + return 'NaN' + + + def _print_list(self, expr): + return 'Any[' + ', '.join(self._print(a) for a in expr) + ']' + + + def _print_tuple(self, expr): + if len(expr) == 1: + return "(%s,)" % self._print(expr[0]) + else: + return "(%s)" % self.stringify(expr, ", ") + _print_Tuple = _print_tuple + + + def _print_BooleanTrue(self, expr): + return "true" + + + def _print_BooleanFalse(self, expr): + return "false" + + + def _print_bool(self, expr): + return str(expr).lower() + + + # Could generate quadrature code for definite Integrals? + #_print_Integral = _print_not_supported + + + def _print_MatrixBase(self, A): + # Handle zero dimensions: + if S.Zero in A.shape: + return 'zeros(%s, %s)' % (A.rows, A.cols) + elif (A.rows, A.cols) == (1, 1): + return "[%s]" % A[0, 0] + elif A.rows == 1: + return "[%s]" % A.table(self, rowstart='', rowend='', colsep=' ') + elif A.cols == 1: + # note .table would unnecessarily equispace the rows + return "[%s]" % ", ".join([self._print(a) for a in A]) + return "[%s]" % A.table(self, rowstart='', rowend='', + rowsep=';\n', colsep=' ') + + + def _print_SparseRepMatrix(self, A): + from sympy.matrices import Matrix + L = A.col_list() + # make row vectors of the indices and entries + I = Matrix([k[0] + 1 for k in L]) + J = Matrix([k[1] + 1 for k in L]) + AIJ = Matrix([k[2] for k in L]) + return "sparse(%s, %s, %s, %s, %s)" % (self._print(I), self._print(J), + self._print(AIJ), A.rows, A.cols) + + + def _print_MatrixElement(self, expr): + return self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True) \ + + '[%s,%s]' % (expr.i + 1, expr.j + 1) + + + def _print_MatrixSlice(self, expr): + def strslice(x, lim): + l = x[0] + 1 + h = x[1] + step = x[2] + lstr = self._print(l) + hstr = 'end' if h == lim else self._print(h) + if step == 1: + if l == 1 and h == lim: + return ':' + if l == h: + return lstr + else: + return lstr + ':' + hstr + else: + return ':'.join((lstr, self._print(step), hstr)) + return (self._print(expr.parent) + '[' + + strslice(expr.rowslice, expr.parent.shape[0]) + ',' + + strslice(expr.colslice, expr.parent.shape[1]) + ']') + + + def _print_Indexed(self, expr): + inds = [ self._print(i) for i in expr.indices ] + return "%s[%s]" % (self._print(expr.base.label), ",".join(inds)) + + def _print_Identity(self, expr): + return "eye(%s)" % self._print(expr.shape[0]) + + def _print_HadamardProduct(self, expr): + return ' .* '.join([self.parenthesize(arg, precedence(expr)) + for arg in expr.args]) + + def _print_HadamardPower(self, expr): + PREC = precedence(expr) + return '.**'.join([ + self.parenthesize(expr.base, PREC), + self.parenthesize(expr.exp, PREC) + ]) + + def _print_Rational(self, expr): + if expr.q == 1: + return str(expr.p) + return "%s // %s" % (expr.p, expr.q) + + # Note: as of 2022, Julia doesn't have spherical Bessel functions + def _print_jn(self, expr): + from sympy.functions import sqrt, besselj + x = expr.argument + expr2 = sqrt(S.Pi/(2*x))*besselj(expr.order + S.Half, x) + return self._print(expr2) + + + def _print_yn(self, expr): + from sympy.functions import sqrt, bessely + x = expr.argument + expr2 = sqrt(S.Pi/(2*x))*bessely(expr.order + S.Half, x) + return self._print(expr2) + + def _print_sinc(self, expr): + # Julia has the normalized sinc function + return "sinc({})".format(self._print(expr.args[0] / S.Pi)) + + def _print_Piecewise(self, expr): + if expr.args[-1].cond != True: + # We need the last conditional to be a True, otherwise the resulting + # function may not return a result. + raise ValueError("All Piecewise expressions must contain an " + "(expr, True) statement to be used as a default " + "condition. Without one, the generated " + "expression may not evaluate to anything under " + "some condition.") + lines = [] + if self._settings["inline"]: + # Express each (cond, expr) pair in a nested Horner form: + # (condition) .* (expr) + (not cond) .* () + # Expressions that result in multiple statements won't work here. + ecpairs = ["({}) ? ({}) :".format + (self._print(c), self._print(e)) + for e, c in expr.args[:-1]] + elast = " (%s)" % self._print(expr.args[-1].expr) + pw = "\n".join(ecpairs) + elast + # Note: current need these outer brackets for 2*pw. Would be + # nicer to teach parenthesize() to do this for us when needed! + return "(" + pw + ")" + else: + for i, (e, c) in enumerate(expr.args): + if i == 0: + lines.append("if (%s)" % self._print(c)) + elif i == len(expr.args) - 1 and c == True: + lines.append("else") + else: + lines.append("elseif (%s)" % self._print(c)) + code0 = self._print(e) + lines.append(code0) + if i == len(expr.args) - 1: + lines.append("end") + return "\n".join(lines) + + def _print_MatMul(self, expr): + c, m = expr.as_coeff_mmul() + + sign = "" + if c.is_number: + re, im = c.as_real_imag() + if im.is_zero and re.is_negative: + expr = _keep_coeff(-c, m) + sign = "-" + elif re.is_zero and im.is_negative: + expr = _keep_coeff(-c, m) + sign = "-" + + return sign + ' * '.join( + (self.parenthesize(arg, precedence(expr)) for arg in expr.args) + ) + + + def indent_code(self, code): + """Accepts a string of code or a list of code lines""" + + # code mostly copied from ccode + if isinstance(code, str): + code_lines = self.indent_code(code.splitlines(True)) + return ''.join(code_lines) + + tab = " " + inc_regex = ('^function ', '^if ', '^elseif ', '^else$', '^for ') + dec_regex = ('^end$', '^elseif ', '^else$') + + # pre-strip left-space from the code + code = [ line.lstrip(' \t') for line in code ] + + increase = [ int(any(search(re, line) for re in inc_regex)) + for line in code ] + decrease = [ int(any(search(re, line) for re in dec_regex)) + for line in code ] + + pretty = [] + level = 0 + for n, line in enumerate(code): + if line in ('', '\n'): + pretty.append(line) + continue + level -= decrease[n] + pretty.append("%s%s" % (tab*level, line)) + level += increase[n] + return pretty + + +def julia_code(expr, assign_to=None, **settings): + r"""Converts `expr` to a string of Julia code. + + Parameters + ========== + + expr : Expr + A SymPy expression to be converted. + assign_to : optional + When given, the argument is used as the name of the variable to which + the expression is assigned. Can be a string, ``Symbol``, + ``MatrixSymbol``, or ``Indexed`` type. This can be helpful for + expressions that generate multi-line statements. + precision : integer, optional + The precision for numbers such as pi [default=16]. + user_functions : dict, optional + A dictionary where keys are ``FunctionClass`` instances and values are + their string representations. Alternatively, the dictionary value can + be a list of tuples i.e. [(argument_test, cfunction_string)]. See + below for examples. + human : bool, optional + If True, the result is a single string that may contain some constant + declarations for the number symbols. If False, the same information is + returned in a tuple of (symbols_to_declare, not_supported_functions, + code_text). [default=True]. + contract: bool, optional + If True, ``Indexed`` instances are assumed to obey tensor contraction + rules and the corresponding nested loops over indices are generated. + Setting contract=False will not generate loops, instead the user is + responsible to provide values for the indices in the code. + [default=True]. + inline: bool, optional + If True, we try to create single-statement code instead of multiple + statements. [default=True]. + + Examples + ======== + + >>> from sympy import julia_code, symbols, sin, pi + >>> x = symbols('x') + >>> julia_code(sin(x).series(x).removeO()) + 'x .^ 5 / 120 - x .^ 3 / 6 + x' + + >>> from sympy import Rational, ceiling + >>> x, y, tau = symbols("x, y, tau") + >>> julia_code((2*tau)**Rational(7, 2)) + '8 * sqrt(2) * tau .^ (7 // 2)' + + Note that element-wise (Hadamard) operations are used by default between + symbols. This is because its possible in Julia to write "vectorized" + code. It is harmless if the values are scalars. + + >>> julia_code(sin(pi*x*y), assign_to="s") + 's = sin(pi * x .* y)' + + If you need a matrix product "*" or matrix power "^", you can specify the + symbol as a ``MatrixSymbol``. + + >>> from sympy import Symbol, MatrixSymbol + >>> n = Symbol('n', integer=True, positive=True) + >>> A = MatrixSymbol('A', n, n) + >>> julia_code(3*pi*A**3) + '(3 * pi) * A ^ 3' + + This class uses several rules to decide which symbol to use a product. + Pure numbers use "*", Symbols use ".*" and MatrixSymbols use "*". + A HadamardProduct can be used to specify componentwise multiplication ".*" + of two MatrixSymbols. There is currently there is no easy way to specify + scalar symbols, so sometimes the code might have some minor cosmetic + issues. For example, suppose x and y are scalars and A is a Matrix, then + while a human programmer might write "(x^2*y)*A^3", we generate: + + >>> julia_code(x**2*y*A**3) + '(x .^ 2 .* y) * A ^ 3' + + Matrices are supported using Julia inline notation. When using + ``assign_to`` with matrices, the name can be specified either as a string + or as a ``MatrixSymbol``. The dimensions must align in the latter case. + + >>> from sympy import Matrix, MatrixSymbol + >>> mat = Matrix([[x**2, sin(x), ceiling(x)]]) + >>> julia_code(mat, assign_to='A') + 'A = [x .^ 2 sin(x) ceil(x)]' + + ``Piecewise`` expressions are implemented with logical masking by default. + Alternatively, you can pass "inline=False" to use if-else conditionals. + Note that if the ``Piecewise`` lacks a default term, represented by + ``(expr, True)`` then an error will be thrown. This is to prevent + generating an expression that may not evaluate to anything. + + >>> from sympy import Piecewise + >>> pw = Piecewise((x + 1, x > 0), (x, True)) + >>> julia_code(pw, assign_to=tau) + 'tau = ((x > 0) ? (x + 1) : (x))' + + Note that any expression that can be generated normally can also exist + inside a Matrix: + + >>> mat = Matrix([[x**2, pw, sin(x)]]) + >>> julia_code(mat, assign_to='A') + 'A = [x .^ 2 ((x > 0) ? (x + 1) : (x)) sin(x)]' + + Custom printing can be defined for certain types by passing a dictionary of + "type" : "function" to the ``user_functions`` kwarg. Alternatively, the + dictionary value can be a list of tuples i.e., [(argument_test, + cfunction_string)]. This can be used to call a custom Julia function. + + >>> from sympy import Function + >>> f = Function('f') + >>> g = Function('g') + >>> custom_functions = { + ... "f": "existing_julia_fcn", + ... "g": [(lambda x: x.is_Matrix, "my_mat_fcn"), + ... (lambda x: not x.is_Matrix, "my_fcn")] + ... } + >>> mat = Matrix([[1, x]]) + >>> julia_code(f(x) + g(x) + g(mat), user_functions=custom_functions) + 'existing_julia_fcn(x) + my_fcn(x) + my_mat_fcn([1 x])' + + Support for loops is provided through ``Indexed`` types. With + ``contract=True`` these expressions will be turned into loops, whereas + ``contract=False`` will just print the assignment expression that should be + looped over: + + >>> from sympy import Eq, IndexedBase, Idx + >>> len_y = 5 + >>> y = IndexedBase('y', shape=(len_y,)) + >>> t = IndexedBase('t', shape=(len_y,)) + >>> Dy = IndexedBase('Dy', shape=(len_y-1,)) + >>> i = Idx('i', len_y-1) + >>> e = Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i])) + >>> julia_code(e.rhs, assign_to=e.lhs, contract=False) + 'Dy[i] = (y[i + 1] - y[i]) ./ (t[i + 1] - t[i])' + """ + return JuliaCodePrinter(settings).doprint(expr, assign_to) + + +def print_julia_code(expr, **settings): + """Prints the Julia representation of the given expression. + + See `julia_code` for the meaning of the optional arguments. + """ + print(julia_code(expr, **settings)) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/lambdarepr.py b/.venv/lib/python3.13/site-packages/sympy/printing/lambdarepr.py new file mode 100644 index 0000000000000000000000000000000000000000..87fa0988d138d54d68ab8aef1bbc0f27b243b472 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/lambdarepr.py @@ -0,0 +1,251 @@ +from .pycode import ( + PythonCodePrinter, + MpmathPrinter, +) +from .numpy import NumPyPrinter # NumPyPrinter is imported for backward compatibility +from sympy.core.sorting import default_sort_key + + +__all__ = [ + 'PythonCodePrinter', + 'MpmathPrinter', # MpmathPrinter is published for backward compatibility + 'NumPyPrinter', + 'LambdaPrinter', + 'NumPyPrinter', + 'IntervalPrinter', + 'lambdarepr', +] + + +class LambdaPrinter(PythonCodePrinter): + """ + This printer converts expressions into strings that can be used by + lambdify. + """ + printmethod = "_lambdacode" + + + def _print_And(self, expr): + result = ['('] + for arg in sorted(expr.args, key=default_sort_key): + result.extend(['(', self._print(arg), ')']) + result.append(' and ') + result = result[:-1] + result.append(')') + return ''.join(result) + + def _print_Or(self, expr): + result = ['('] + for arg in sorted(expr.args, key=default_sort_key): + result.extend(['(', self._print(arg), ')']) + result.append(' or ') + result = result[:-1] + result.append(')') + return ''.join(result) + + def _print_Not(self, expr): + result = ['(', 'not (', self._print(expr.args[0]), '))'] + return ''.join(result) + + def _print_BooleanTrue(self, expr): + return "True" + + def _print_BooleanFalse(self, expr): + return "False" + + def _print_ITE(self, expr): + result = [ + '((', self._print(expr.args[1]), + ') if (', self._print(expr.args[0]), + ') else (', self._print(expr.args[2]), '))' + ] + return ''.join(result) + + def _print_NumberSymbol(self, expr): + return str(expr) + + def _print_Pow(self, expr, **kwargs): + # XXX Temporary workaround. Should Python math printer be + # isolated from PythonCodePrinter? + return super(PythonCodePrinter, self)._print_Pow(expr, **kwargs) + + +# numexpr works by altering the string passed to numexpr.evaluate +# rather than by populating a namespace. Thus a special printer... +class NumExprPrinter(LambdaPrinter): + # key, value pairs correspond to SymPy name and numexpr name + # functions not appearing in this dict will raise a TypeError + printmethod = "_numexprcode" + + _numexpr_functions = { + 'sin' : 'sin', + 'cos' : 'cos', + 'tan' : 'tan', + 'asin': 'arcsin', + 'acos': 'arccos', + 'atan': 'arctan', + 'atan2' : 'arctan2', + 'sinh' : 'sinh', + 'cosh' : 'cosh', + 'tanh' : 'tanh', + 'asinh': 'arcsinh', + 'acosh': 'arccosh', + 'atanh': 'arctanh', + 'ln' : 'log', + 'log': 'log', + 'exp': 'exp', + 'sqrt' : 'sqrt', + 'Abs' : 'abs', + 'conjugate' : 'conj', + 'im' : 'imag', + 're' : 'real', + 'where' : 'where', + 'complex' : 'complex', + 'contains' : 'contains', + } + + module = 'numexpr' + + def _print_ImaginaryUnit(self, expr): + return '1j' + + def _print_seq(self, seq, delimiter=', '): + # simplified _print_seq taken from pretty.py + s = [self._print(item) for item in seq] + if s: + return delimiter.join(s) + else: + return "" + + def _print_Function(self, e): + func_name = e.func.__name__ + + nstr = self._numexpr_functions.get(func_name, None) + if nstr is None: + # check for implemented_function + if hasattr(e, '_imp_'): + return "(%s)" % self._print(e._imp_(*e.args)) + else: + raise TypeError("numexpr does not support function '%s'" % + func_name) + return "%s(%s)" % (nstr, self._print_seq(e.args)) + + def _print_Piecewise(self, expr): + "Piecewise function printer" + exprs = [self._print(arg.expr) for arg in expr.args] + conds = [self._print(arg.cond) for arg in expr.args] + # If [default_value, True] is a (expr, cond) sequence in a Piecewise object + # it will behave the same as passing the 'default' kwarg to select() + # *as long as* it is the last element in expr.args. + # If this is not the case, it may be triggered prematurely. + ans = [] + parenthesis_count = 0 + is_last_cond_True = False + for cond, expr in zip(conds, exprs): + if cond == 'True': + ans.append(expr) + is_last_cond_True = True + break + else: + ans.append('where(%s, %s, ' % (cond, expr)) + parenthesis_count += 1 + if not is_last_cond_True: + # See https://github.com/pydata/numexpr/issues/298 + # + # simplest way to put a nan but raises + # 'RuntimeWarning: invalid value encountered in log' + # + # There are other ways to do this such as + # + # >>> import numexpr as ne + # >>> nan = float('nan') + # >>> ne.evaluate('where(x < 0, -1, nan)', {'x': [-1, 2, 3], 'nan':nan}) + # array([-1., nan, nan]) + # + # That needs to be handled in the lambdified function though rather + # than here in the printer. + ans.append('log(-1)') + return ''.join(ans) + ')' * parenthesis_count + + def _print_ITE(self, expr): + from sympy.functions.elementary.piecewise import Piecewise + return self._print(expr.rewrite(Piecewise)) + + def blacklisted(self, expr): + raise TypeError("numexpr cannot be used with %s" % + expr.__class__.__name__) + + # blacklist all Matrix printing + _print_SparseRepMatrix = \ + _print_MutableSparseMatrix = \ + _print_ImmutableSparseMatrix = \ + _print_Matrix = \ + _print_DenseMatrix = \ + _print_MutableDenseMatrix = \ + _print_ImmutableMatrix = \ + _print_ImmutableDenseMatrix = \ + blacklisted + # blacklist some Python expressions + _print_list = \ + _print_tuple = \ + _print_Tuple = \ + _print_dict = \ + _print_Dict = \ + blacklisted + + def _print_NumExprEvaluate(self, expr): + evaluate = self._module_format(self.module +".evaluate") + return "%s('%s', truediv=True)" % (evaluate, self._print(expr.expr)) + + def doprint(self, expr): + from sympy.codegen.ast import CodegenAST + from sympy.codegen.pynodes import NumExprEvaluate + if not isinstance(expr, CodegenAST): + expr = NumExprEvaluate(expr) + return super().doprint(expr) + + def _print_Return(self, expr): + from sympy.codegen.pynodes import NumExprEvaluate + r, = expr.args + if not isinstance(r, NumExprEvaluate): + expr = expr.func(NumExprEvaluate(r)) + return super()._print_Return(expr) + + def _print_Assignment(self, expr): + from sympy.codegen.pynodes import NumExprEvaluate + lhs, rhs, *args = expr.args + if not isinstance(rhs, NumExprEvaluate): + expr = expr.func(lhs, NumExprEvaluate(rhs), *args) + return super()._print_Assignment(expr) + + def _print_CodeBlock(self, expr): + from sympy.codegen.ast import CodegenAST + from sympy.codegen.pynodes import NumExprEvaluate + args = [ arg if isinstance(arg, CodegenAST) else NumExprEvaluate(arg) for arg in expr.args ] + return super()._print_CodeBlock(self, expr.func(*args)) + + +class IntervalPrinter(MpmathPrinter, LambdaPrinter): + """Use ``lambda`` printer but print numbers as ``mpi`` intervals. """ + + def _print_Integer(self, expr): + return "mpi('%s')" % super(PythonCodePrinter, self)._print_Integer(expr) + + def _print_Rational(self, expr): + return "mpi('%s')" % super(PythonCodePrinter, self)._print_Rational(expr) + + def _print_Half(self, expr): + return "mpi('%s')" % super(PythonCodePrinter, self)._print_Rational(expr) + + def _print_Pow(self, expr): + return super(MpmathPrinter, self)._print_Pow(expr, rational=True) + + +for k in NumExprPrinter._numexpr_functions: + setattr(NumExprPrinter, '_print_%s' % k, NumExprPrinter._print_Function) + +def lambdarepr(expr, **settings): + """ + Returns a string usable for lambdifying. + """ + return LambdaPrinter(settings).doprint(expr) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/latex.py b/.venv/lib/python3.13/site-packages/sympy/printing/latex.py new file mode 100644 index 0000000000000000000000000000000000000000..724df719d560e001deb175649a3769703bdf5ca5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/latex.py @@ -0,0 +1,3318 @@ +""" +A Printer which converts an expression into its LaTeX equivalent. +""" +from __future__ import annotations +from typing import Any, Callable, TYPE_CHECKING + +import itertools + +from sympy.core import Add, Float, Mod, Mul, Number, S, Symbol, Expr +from sympy.core.alphabets import greeks +from sympy.core.containers import Tuple +from sympy.core.function import Function, AppliedUndef, Derivative +from sympy.core.operations import AssocOp +from sympy.core.power import Pow +from sympy.core.sorting import default_sort_key +from sympy.core.sympify import SympifyError +from sympy.logic.boolalg import true, BooleanTrue, BooleanFalse + + +# sympy.printing imports +from sympy.printing.precedence import precedence_traditional +from sympy.printing.printer import Printer, print_function +from sympy.printing.conventions import split_super_sub, requires_partial +from sympy.printing.precedence import precedence, PRECEDENCE + +from mpmath.libmp.libmpf import prec_to_dps, to_str as mlib_to_str + +from sympy.utilities.iterables import has_variety, sift + +import re + +if TYPE_CHECKING: + from sympy.tensor.array import NDimArray + from sympy.vector.basisdependent import BasisDependent + +# Hand-picked functions which can be used directly in both LaTeX and MathJax +# Complete list at +# https://docs.mathjax.org/en/latest/tex.html#supported-latex-commands +# This variable only contains those functions which SymPy uses. +accepted_latex_functions = ['arcsin', 'arccos', 'arctan', 'sin', 'cos', 'tan', + 'sinh', 'cosh', 'tanh', 'sqrt', 'ln', 'log', 'sec', + 'csc', 'cot', 'coth', 're', 'im', 'frac', 'root', + 'arg', + ] + +tex_greek_dictionary = { + 'Alpha': r'\mathrm{A}', + 'Beta': r'\mathrm{B}', + 'Gamma': r'\Gamma', + 'Delta': r'\Delta', + 'Epsilon': r'\mathrm{E}', + 'Zeta': r'\mathrm{Z}', + 'Eta': r'\mathrm{H}', + 'Theta': r'\Theta', + 'Iota': r'\mathrm{I}', + 'Kappa': r'\mathrm{K}', + 'Lambda': r'\Lambda', + 'Mu': r'\mathrm{M}', + 'Nu': r'\mathrm{N}', + 'Xi': r'\Xi', + 'omicron': 'o', + 'Omicron': r'\mathrm{O}', + 'Pi': r'\Pi', + 'Rho': r'\mathrm{P}', + 'Sigma': r'\Sigma', + 'Tau': r'\mathrm{T}', + 'Upsilon': r'\Upsilon', + 'Phi': r'\Phi', + 'Chi': r'\mathrm{X}', + 'Psi': r'\Psi', + 'Omega': r'\Omega', + 'lamda': r'\lambda', + 'Lamda': r'\Lambda', + 'khi': r'\chi', + 'Khi': r'\mathrm{X}', + 'varepsilon': r'\varepsilon', + 'varkappa': r'\varkappa', + 'varphi': r'\varphi', + 'varpi': r'\varpi', + 'varrho': r'\varrho', + 'varsigma': r'\varsigma', + 'vartheta': r'\vartheta', +} + +other_symbols = {'aleph', 'beth', 'daleth', 'gimel', 'ell', 'eth', 'hbar', + 'hslash', 'mho', 'wp'} + +# Variable name modifiers +modifier_dict: dict[str, Callable[[str], str]] = { + # Accents + 'mathring': lambda s: r'\mathring{'+s+r'}', + 'ddddot': lambda s: r'\ddddot{'+s+r'}', + 'dddot': lambda s: r'\dddot{'+s+r'}', + 'ddot': lambda s: r'\ddot{'+s+r'}', + 'dot': lambda s: r'\dot{'+s+r'}', + 'check': lambda s: r'\check{'+s+r'}', + 'breve': lambda s: r'\breve{'+s+r'}', + 'acute': lambda s: r'\acute{'+s+r'}', + 'grave': lambda s: r'\grave{'+s+r'}', + 'tilde': lambda s: r'\tilde{'+s+r'}', + 'hat': lambda s: r'\hat{'+s+r'}', + 'bar': lambda s: r'\bar{'+s+r'}', + 'vec': lambda s: r'\vec{'+s+r'}', + 'prime': lambda s: "{"+s+"}'", + 'prm': lambda s: "{"+s+"}'", + # Faces + 'bold': lambda s: r'\boldsymbol{'+s+r'}', + 'bm': lambda s: r'\boldsymbol{'+s+r'}', + 'cal': lambda s: r'\mathcal{'+s+r'}', + 'scr': lambda s: r'\mathscr{'+s+r'}', + 'frak': lambda s: r'\mathfrak{'+s+r'}', + # Brackets + 'norm': lambda s: r'\left\|{'+s+r'}\right\|', + 'avg': lambda s: r'\left\langle{'+s+r'}\right\rangle', + 'abs': lambda s: r'\left|{'+s+r'}\right|', + 'mag': lambda s: r'\left|{'+s+r'}\right|', +} + +greek_letters_set = frozenset(greeks) + +_between_two_numbers_p = ( + re.compile(r'[0-9][} ]*$'), # search + re.compile(r'(\d|\\frac{\d+}{\d+})'), # match +) + + +def latex_escape(s: str) -> str: + """ + Escape a string such that latex interprets it as plaintext. + + We cannot use verbatim easily with mathjax, so escaping is easier. + Rules from https://tex.stackexchange.com/a/34586/41112. + """ + s = s.replace('\\', r'\textbackslash') + for c in '&%$#_{}': + s = s.replace(c, '\\' + c) + s = s.replace('~', r'\textasciitilde') + s = s.replace('^', r'\textasciicircum') + return s + + +class LatexPrinter(Printer): + printmethod = "_latex" + + _default_settings: dict[str, Any] = { + "full_prec": False, + "fold_frac_powers": False, + "fold_func_brackets": False, + "fold_short_frac": None, + "inv_trig_style": "abbreviated", + "itex": False, + "ln_notation": False, + "long_frac_ratio": None, + "mat_delim": "[", + "mat_str": None, + "mode": "plain", + "mul_symbol": None, + "order": None, + "symbol_names": {}, + "root_notation": True, + "mat_symbol_style": "plain", + "imaginary_unit": "i", + "gothic_re_im": False, + "decimal_separator": "period", + "perm_cyclic": True, + "parenthesize_super": True, + "min": None, + "max": None, + "diff_operator": "d", + "adjoint_style": "dagger", + "disable_split_super_sub": False, + } + + def __init__(self, settings=None): + Printer.__init__(self, settings) + + if 'mode' in self._settings: + valid_modes = ['inline', 'plain', 'equation', + 'equation*'] + if self._settings['mode'] not in valid_modes: + raise ValueError("'mode' must be one of 'inline', 'plain', " + "'equation' or 'equation*'") + + if self._settings['fold_short_frac'] is None and \ + self._settings['mode'] == 'inline': + self._settings['fold_short_frac'] = True + + mul_symbol_table = { + None: r" ", + "ldot": r" \,.\, ", + "dot": r" \cdot ", + "times": r" \times " + } + try: + self._settings['mul_symbol_latex'] = \ + mul_symbol_table[self._settings['mul_symbol']] + except KeyError: + self._settings['mul_symbol_latex'] = \ + self._settings['mul_symbol'] + try: + self._settings['mul_symbol_latex_numbers'] = \ + mul_symbol_table[self._settings['mul_symbol'] or 'dot'] + except KeyError: + if (self._settings['mul_symbol'].strip() in + ['', ' ', '\\', '\\,', '\\:', '\\;', '\\quad']): + self._settings['mul_symbol_latex_numbers'] = \ + mul_symbol_table['dot'] + else: + self._settings['mul_symbol_latex_numbers'] = \ + self._settings['mul_symbol'] + + self._delim_dict = {'(': ')', '[': ']'} + + imaginary_unit_table = { + None: r"i", + "i": r"i", + "ri": r"\mathrm{i}", + "ti": r"\text{i}", + "j": r"j", + "rj": r"\mathrm{j}", + "tj": r"\text{j}", + } + imag_unit = self._settings['imaginary_unit'] + self._settings['imaginary_unit_latex'] = imaginary_unit_table.get(imag_unit, imag_unit) + + diff_operator_table = { + None: r"d", + "d": r"d", + "rd": r"\mathrm{d}", + "td": r"\text{d}", + } + diff_operator = self._settings['diff_operator'] + self._settings["diff_operator_latex"] = diff_operator_table.get(diff_operator, diff_operator) + + def _add_parens(self, s) -> str: + return r"\left({}\right)".format(s) + + # TODO: merge this with the above, which requires a lot of test changes + def _add_parens_lspace(self, s) -> str: + return r"\left( {}\right)".format(s) + + def parenthesize(self, item, level, is_neg=False, strict=False) -> str: + prec_val = precedence_traditional(item) + if is_neg and strict: + return self._add_parens(self._print(item)) + + if (prec_val < level) or ((not strict) and prec_val <= level): + return self._add_parens(self._print(item)) + else: + return self._print(item) + + def parenthesize_super(self, s): + """ + Protect superscripts in s + + If the parenthesize_super option is set, protect with parentheses, else + wrap in braces. + """ + if "^" in s: + if self._settings['parenthesize_super']: + return self._add_parens(s) + else: + return "{{{}}}".format(s) + return s + + def doprint(self, expr) -> str: + tex = Printer.doprint(self, expr) + + if self._settings['mode'] == 'plain': + return tex + elif self._settings['mode'] == 'inline': + return r"$%s$" % tex + elif self._settings['itex']: + return r"$$%s$$" % tex + else: + env_str = self._settings['mode'] + return r"\begin{%s}%s\end{%s}" % (env_str, tex, env_str) + + def _needs_brackets(self, expr) -> bool: + """ + Returns True if the expression needs to be wrapped in brackets when + printed, False otherwise. For example: a + b => True; a => False; + 10 => False; -10 => True. + """ + return not ((expr.is_Integer and expr.is_nonnegative) + or (expr.is_Atom and (expr is not S.NegativeOne + and expr.is_Rational is False))) + + def _needs_function_brackets(self, expr) -> bool: + """ + Returns True if the expression needs to be wrapped in brackets when + passed as an argument to a function, False otherwise. This is a more + liberal version of _needs_brackets, in that many expressions which need + to be wrapped in brackets when added/subtracted/raised to a power do + not need them when passed to a function. Such an example is a*b. + """ + if not self._needs_brackets(expr): + return False + else: + # Muls of the form a*b*c... can be folded + if expr.is_Mul and not self._mul_is_clean(expr): + return True + # Pows which don't need brackets can be folded + elif expr.is_Pow and not self._pow_is_clean(expr): + return True + # Add and Function always need brackets + elif expr.is_Add or expr.is_Function: + return True + else: + return False + + def _needs_mul_brackets(self, expr, first=False, last=False) -> bool: + """ + Returns True if the expression needs to be wrapped in brackets when + printed as part of a Mul, False otherwise. This is True for Add, + but also for some container objects that would not need brackets + when appearing last in a Mul, e.g. an Integral. ``last=True`` + specifies that this expr is the last to appear in a Mul. + ``first=True`` specifies that this expr is the first to appear in + a Mul. + """ + from sympy.concrete.products import Product + from sympy.concrete.summations import Sum + from sympy.integrals.integrals import Integral + + if expr.is_Mul: + if not first and expr.could_extract_minus_sign(): + return True + elif precedence_traditional(expr) < PRECEDENCE["Mul"]: + return True + elif expr.is_Relational: + return True + if expr.is_Piecewise: + return True + if any(expr.has(x) for x in (Mod,)): + return True + if (not last and + any(expr.has(x) for x in (Integral, Product, Sum))): + return True + + return False + + def _needs_add_brackets(self, expr) -> bool: + """ + Returns True if the expression needs to be wrapped in brackets when + printed as part of an Add, False otherwise. This is False for most + things. + """ + if expr.is_Relational: + return True + if any(expr.has(x) for x in (Mod,)): + return True + if expr.is_Add: + return True + return False + + def _mul_is_clean(self, expr) -> bool: + for arg in expr.args: + if arg.is_Function: + return False + return True + + def _pow_is_clean(self, expr) -> bool: + return not self._needs_brackets(expr.base) + + def _do_exponent(self, expr: str, exp): + if exp is not None: + return r"\left(%s\right)^{%s}" % (expr, exp) + else: + return expr + + def _print_Basic(self, expr): + name = self._deal_with_super_sub(expr.__class__.__name__) + if expr.args: + ls = [self._print(o) for o in expr.args] + s = r"\operatorname{{{}}}\left({}\right)" + return s.format(name, ", ".join(ls)) + else: + return r"\text{{{}}}".format(name) + + def _print_bool(self, e: bool | BooleanTrue | BooleanFalse): + return r"\text{%s}" % e + + _print_BooleanTrue = _print_bool + _print_BooleanFalse = _print_bool + + def _print_NoneType(self, e): + return r"\text{%s}" % e + + def _print_Add(self, expr, order=None): + terms = self._as_ordered_terms(expr, order=order) + + tex = "" + for i, term in enumerate(terms): + if i == 0: + pass + elif term.could_extract_minus_sign(): + tex += " - " + term = -term + else: + tex += " + " + term_tex = self._print(term) + if self._needs_add_brackets(term): + term_tex = r"\left(%s\right)" % term_tex + tex += term_tex + + return tex + + def _print_Cycle(self, expr): + from sympy.combinatorics.permutations import Permutation + if expr.size == 0: + return r"\left( \right)" + expr = Permutation(expr) + expr_perm = expr.cyclic_form + siz = expr.size + if expr.array_form[-1] == siz - 1: + expr_perm = expr_perm + [[siz - 1]] + term_tex = '' + for i in expr_perm: + term_tex += str(i).replace(',', r"\;") + term_tex = term_tex.replace('[', r"\left( ") + term_tex = term_tex.replace(']', r"\right)") + return term_tex + + def _print_Permutation(self, expr): + from sympy.combinatorics.permutations import Permutation + from sympy.utilities.exceptions import sympy_deprecation_warning + + perm_cyclic = Permutation.print_cyclic + if perm_cyclic is not None: + sympy_deprecation_warning( + f""" + Setting Permutation.print_cyclic is deprecated. Instead use + init_printing(perm_cyclic={perm_cyclic}). + """, + deprecated_since_version="1.6", + active_deprecations_target="deprecated-permutation-print_cyclic", + stacklevel=8, + ) + else: + perm_cyclic = self._settings.get("perm_cyclic", True) + + if perm_cyclic: + return self._print_Cycle(expr) + + if expr.size == 0: + return r"\left( \right)" + + lower = [self._print(arg) for arg in expr.array_form] + upper = [self._print(arg) for arg in range(len(lower))] + + row1 = " & ".join(upper) + row2 = " & ".join(lower) + mat = r" \\ ".join((row1, row2)) + return r"\begin{pmatrix} %s \end{pmatrix}" % mat + + + def _print_AppliedPermutation(self, expr): + perm, var = expr.args + return r"\sigma_{%s}(%s)" % (self._print(perm), self._print(var)) + + def _print_Float(self, expr): + # Based off of that in StrPrinter + dps = prec_to_dps(expr._prec) + strip = False if self._settings['full_prec'] else True + low = self._settings["min"] if "min" in self._settings else None + high = self._settings["max"] if "max" in self._settings else None + str_real = mlib_to_str(expr._mpf_, dps, strip_zeros=strip, min_fixed=low, max_fixed=high) + + # Must always have a mul symbol (as 2.5 10^{20} just looks odd) + # thus we use the number separator + separator = self._settings['mul_symbol_latex_numbers'] + + if 'e' in str_real: + (mant, exp) = str_real.split('e') + + if exp[0] == '+': + exp = exp[1:] + if self._settings['decimal_separator'] == 'comma': + mant = mant.replace('.','{,}') + + return r"%s%s10^{%s}" % (mant, separator, exp) + elif str_real == "+inf": + return r"\infty" + elif str_real == "-inf": + return r"- \infty" + else: + if self._settings['decimal_separator'] == 'comma': + str_real = str_real.replace('.','{,}') + return str_real + + def _print_Cross(self, expr): + vec1 = expr._expr1 + vec2 = expr._expr2 + return r"%s \times %s" % (self.parenthesize(vec1, PRECEDENCE['Mul']), + self.parenthesize(vec2, PRECEDENCE['Mul'])) + + def _print_Curl(self, expr): + vec = expr._expr + return r"\nabla\times %s" % self.parenthesize(vec, PRECEDENCE['Mul']) + + def _print_Divergence(self, expr): + vec = expr._expr + return r"\nabla\cdot %s" % self.parenthesize(vec, PRECEDENCE['Mul']) + + def _print_Dot(self, expr): + vec1 = expr._expr1 + vec2 = expr._expr2 + return r"%s \cdot %s" % (self.parenthesize(vec1, PRECEDENCE['Mul']), + self.parenthesize(vec2, PRECEDENCE['Mul'])) + + def _print_Gradient(self, expr): + func = expr._expr + return r"\nabla %s" % self.parenthesize(func, PRECEDENCE['Mul']) + + def _print_Laplacian(self, expr): + func = expr._expr + return r"\Delta %s" % self.parenthesize(func, PRECEDENCE['Mul']) + + def _print_Mul(self, expr: Expr): + from sympy.simplify import fraction + separator: str = self._settings['mul_symbol_latex'] + numbersep: str = self._settings['mul_symbol_latex_numbers'] + + def convert(expr) -> str: + if not expr.is_Mul: + return str(self._print(expr)) + else: + if self.order not in ('old', 'none'): + args = expr.as_ordered_factors() + else: + args = list(expr.args) + + # If there are quantities or prefixes, append them at the back. + units, nonunits = sift(args, lambda x: (hasattr(x, "_scale_factor") or hasattr(x, "is_physical_constant")) or + (isinstance(x, Pow) and + hasattr(x.base, "is_physical_constant")), binary=True) + prefixes, units = sift(units, lambda x: hasattr(x, "_scale_factor"), binary=True) + return convert_args(nonunits + prefixes + units) + + def convert_args(args) -> str: + _tex = last_term_tex = "" + + for i, term in enumerate(args): + term_tex = self._print(term) + if not (hasattr(term, "_scale_factor") or hasattr(term, "is_physical_constant")): + if self._needs_mul_brackets(term, first=(i == 0), + last=(i == len(args) - 1)): + term_tex = r"\left(%s\right)" % term_tex + + if _between_two_numbers_p[0].search(last_term_tex) and \ + _between_two_numbers_p[1].match(term_tex): + # between two numbers + _tex += numbersep + elif _tex: + _tex += separator + elif _tex: + _tex += separator + + _tex += term_tex + last_term_tex = term_tex + return _tex + + # Check for unevaluated Mul. In this case we need to make sure the + # identities are visible, multiple Rational factors are not combined + # etc so we display in a straight-forward form that fully preserves all + # args and their order. + # XXX: _print_Pow calls this routine with instances of Pow... + if isinstance(expr, Mul): + args = expr.args + if args[0] is S.One or any(isinstance(arg, Number) for arg in args[1:]): + return convert_args(args) + + include_parens = False + if expr.could_extract_minus_sign(): + expr = -expr + tex = "- " + if expr.is_Add: + tex += "(" + include_parens = True + else: + tex = "" + + numer, denom = fraction(expr, exact=True) + + if denom is S.One and Pow(1, -1, evaluate=False) not in expr.args: + # use the original expression here, since fraction() may have + # altered it when producing numer and denom + tex += convert(expr) + + else: + snumer = convert(numer) + sdenom = convert(denom) + ldenom = len(sdenom.split()) + ratio = self._settings['long_frac_ratio'] + if self._settings['fold_short_frac'] and ldenom <= 2 and \ + "^" not in sdenom: + # handle short fractions + if self._needs_mul_brackets(numer, last=False): + tex += r"\left(%s\right) / %s" % (snumer, sdenom) + else: + tex += r"%s / %s" % (snumer, sdenom) + elif ratio is not None and \ + len(snumer.split()) > ratio*ldenom: + # handle long fractions + if self._needs_mul_brackets(numer, last=True): + tex += r"\frac{1}{%s}%s\left(%s\right)" \ + % (sdenom, separator, snumer) + elif numer.is_Mul: + # split a long numerator + a = S.One + b = S.One + for x in numer.args: + if self._needs_mul_brackets(x, last=False) or \ + len(convert(a*x).split()) > ratio*ldenom or \ + (b.is_commutative is x.is_commutative is False): + b *= x + else: + a *= x + if self._needs_mul_brackets(b, last=True): + tex += r"\frac{%s}{%s}%s\left(%s\right)" \ + % (convert(a), sdenom, separator, convert(b)) + else: + tex += r"\frac{%s}{%s}%s%s" \ + % (convert(a), sdenom, separator, convert(b)) + else: + tex += r"\frac{1}{%s}%s%s" % (sdenom, separator, snumer) + else: + tex += r"\frac{%s}{%s}" % (snumer, sdenom) + + if include_parens: + tex += ")" + return tex + + def _print_AlgebraicNumber(self, expr): + if expr.is_aliased: + return self._print(expr.as_poly().as_expr()) + else: + return self._print(expr.as_expr()) + + def _print_PrimeIdeal(self, expr): + p = self._print(expr.p) + if expr.is_inert: + return rf'\left({p}\right)' + alpha = self._print(expr.alpha.as_expr()) + return rf'\left({p}, {alpha}\right)' + + def _print_Pow(self, expr: Pow): + # Treat x**Rational(1,n) as special case + if expr.exp.is_Rational: + p: int = expr.exp.p # type: ignore + q: int = expr.exp.q # type: ignore + if abs(p) == 1 and q != 1 and self._settings['root_notation']: + base = self._print(expr.base) + if q == 2: + tex = r"\sqrt{%s}" % base + elif self._settings['itex']: + tex = r"\root{%d}{%s}" % (q, base) + else: + tex = r"\sqrt[%d]{%s}" % (q, base) + if expr.exp.is_negative: + return r"\frac{1}{%s}" % tex + else: + return tex + elif self._settings['fold_frac_powers'] and q != 1: + base = self.parenthesize(expr.base, PRECEDENCE['Pow']) + # issue #12886: add parentheses for superscripts raised to powers + if expr.base.is_Symbol: + base = self.parenthesize_super(base) + if expr.base.is_Function: + return self._print(expr.base, exp="%s/%s" % (p, q)) + return r"%s^{%s/%s}" % (base, p, q) + elif expr.exp.is_negative and expr.base.is_commutative: + # special case for 1^(-x), issue 9216 + if expr.base == 1: + return r"%s^{%s}" % (expr.base, expr.exp) + # special case for (1/x)^(-y) and (-1/-x)^(-y), issue 20252 + if expr.base.is_Rational: + base_p: int = expr.base.p # type: ignore + base_q: int = expr.base.q # type: ignore + if base_p * base_q == abs(base_q): + if expr.exp == -1: + return r"\frac{1}{\frac{%s}{%s}}" % (base_p, base_q) + else: + return r"\frac{1}{(\frac{%s}{%s})^{%s}}" % (base_p, base_q, abs(expr.exp)) + # things like 1/x + return self._print_Mul(expr) + if expr.base.is_Function: + return self._print(expr.base, exp=self._print(expr.exp)) + tex = r"%s^{%s}" + return self._helper_print_standard_power(expr, tex) + + def _helper_print_standard_power(self, expr, template: str) -> str: + exp = self._print(expr.exp) + # issue #12886: add parentheses around superscripts raised + # to powers + base = self.parenthesize(expr.base, PRECEDENCE['Pow']) + if expr.base.is_Symbol: + base = self.parenthesize_super(base) + elif expr.base.is_Float: + base = r"{%s}" % base + elif (isinstance(expr.base, Derivative) + and base.startswith(r'\left(') + and re.match(r'\\left\(\\d?d?dot', base) + and base.endswith(r'\right)')): + # don't use parentheses around dotted derivative + base = base[6: -7] # remove outermost added parens + return template % (base, exp) + + def _print_UnevaluatedExpr(self, expr): + return self._print(expr.args[0]) + + def _print_Sum(self, expr): + if len(expr.limits) == 1: + tex = r"\sum_{%s=%s}^{%s} " % \ + tuple([self._print(i) for i in expr.limits[0]]) + else: + def _format_ineq(l): + return r"%s \leq %s \leq %s" % \ + tuple([self._print(s) for s in (l[1], l[0], l[2])]) + + tex = r"\sum_{\substack{%s}} " % \ + str.join('\\\\', [_format_ineq(l) for l in expr.limits]) + + if isinstance(expr.function, Add): + tex += r"\left(%s\right)" % self._print(expr.function) + else: + tex += self._print(expr.function) + + return tex + + def _print_Product(self, expr): + if len(expr.limits) == 1: + tex = r"\prod_{%s=%s}^{%s} " % \ + tuple([self._print(i) for i in expr.limits[0]]) + else: + def _format_ineq(l): + return r"%s \leq %s \leq %s" % \ + tuple([self._print(s) for s in (l[1], l[0], l[2])]) + + tex = r"\prod_{\substack{%s}} " % \ + str.join('\\\\', [_format_ineq(l) for l in expr.limits]) + + if isinstance(expr.function, Add): + tex += r"\left(%s\right)" % self._print(expr.function) + else: + tex += self._print(expr.function) + + return tex + + def _print_BasisDependent(self, expr: 'BasisDependent'): + from sympy.vector import Vector + + o1: list[str] = [] + if expr == expr.zero: + return expr.zero._latex_form + if isinstance(expr, Vector): + items = expr.separate().items() + else: + items = [(0, expr)] + + for system, vect in items: + inneritems = list(vect.components.items()) + inneritems.sort(key=lambda x: x[0].__str__()) + for k, v in inneritems: + if v == 1: + o1.append(' + ' + k._latex_form) + elif v == -1: + o1.append(' - ' + k._latex_form) + else: + arg_str = r'\left(' + self._print(v) + r'\right)' + o1.append(' + ' + arg_str + k._latex_form) + + outstr = (''.join(o1)) + if outstr[1] != '-': + outstr = outstr[3:] + else: + outstr = outstr[1:] + return outstr + + def _print_Indexed(self, expr): + tex_base = self._print(expr.base) + tex = '{'+tex_base+'}'+'_{%s}' % ','.join( + map(self._print, expr.indices)) + return tex + + def _print_IndexedBase(self, expr): + return self._print(expr.label) + + def _print_Idx(self, expr): + label = self._print(expr.label) + if expr.upper is not None: + upper = self._print(expr.upper) + if expr.lower is not None: + lower = self._print(expr.lower) + else: + lower = self._print(S.Zero) + interval = '{lower}\\mathrel{{..}}\\nobreak {upper}'.format( + lower = lower, upper = upper) + return '{{{label}}}_{{{interval}}}'.format( + label = label, interval = interval) + #if no bounds are defined this just prints the label + return label + + def _print_Derivative(self, expr): + if requires_partial(expr.expr): + diff_symbol = r'\partial' + else: + diff_symbol = self._settings["diff_operator_latex"] + + tex = "" + dim = 0 + for x, num in reversed(expr.variable_count): + dim += num + if num == 1: + tex += r"%s %s" % (diff_symbol, self._print(x)) + else: + tex += r"%s %s^{%s}" % (diff_symbol, + self.parenthesize_super(self._print(x)), + self._print(num)) + + if dim == 1: + tex = r"\frac{%s}{%s}" % (diff_symbol, tex) + else: + tex = r"\frac{%s^{%s}}{%s}" % (diff_symbol, self._print(dim), tex) + + if any(i.could_extract_minus_sign() for i in expr.args): + return r"%s %s" % (tex, self.parenthesize(expr.expr, + PRECEDENCE["Mul"], + is_neg=True, + strict=True)) + + return r"%s %s" % (tex, self.parenthesize(expr.expr, + PRECEDENCE["Mul"], + is_neg=False, + strict=True)) + + def _print_Subs(self, subs): + expr, old, new = subs.args + latex_expr = self._print(expr) + latex_old = (self._print(e) for e in old) + latex_new = (self._print(e) for e in new) + latex_subs = r'\\ '.join( + e[0] + '=' + e[1] for e in zip(latex_old, latex_new)) + return r'\left. %s \right|_{\substack{ %s }}' % (latex_expr, + latex_subs) + + def _print_Integral(self, expr): + tex, symbols = "", [] + diff_symbol = self._settings["diff_operator_latex"] + + # Only up to \iiiint exists + if len(expr.limits) <= 4 and all(len(lim) == 1 for lim in expr.limits): + # Use len(expr.limits)-1 so that syntax highlighters don't think + # \" is an escaped quote + tex = r"\i" + "i"*(len(expr.limits) - 1) + "nt" + symbols = [r"\, %s%s" % (diff_symbol, self._print(symbol[0])) + for symbol in expr.limits] + + else: + for lim in reversed(expr.limits): + symbol = lim[0] + tex += r"\int" + + if len(lim) > 1: + if self._settings['mode'] != 'inline' \ + and not self._settings['itex']: + tex += r"\limits" + + if len(lim) == 3: + tex += "_{%s}^{%s}" % (self._print(lim[1]), + self._print(lim[2])) + if len(lim) == 2: + tex += "^{%s}" % (self._print(lim[1])) + + symbols.insert(0, r"\, %s%s" % (diff_symbol, self._print(symbol))) + + return r"%s %s%s" % (tex, self.parenthesize(expr.function, + PRECEDENCE["Mul"], + is_neg=any(i.could_extract_minus_sign() for i in expr.args), + strict=True), + "".join(symbols)) + + def _print_Limit(self, expr): + e, z, z0, dir = expr.args + + tex = r"\lim_{%s \to " % self._print(z) + if str(dir) == '+-' or z0 in (S.Infinity, S.NegativeInfinity): + tex += r"%s}" % self._print(z0) + else: + tex += r"%s^%s}" % (self._print(z0), self._print(dir)) + + if isinstance(e, AssocOp): + return r"%s\left(%s\right)" % (tex, self._print(e)) + else: + return r"%s %s" % (tex, self._print(e)) + + def _hprint_Function(self, func: str) -> str: + r''' + Logic to decide how to render a function to latex + - if it is a recognized latex name, use the appropriate latex command + - if it is a single letter, excluding sub- and superscripts, just use that letter + - if it is a longer name, then put \operatorname{} around it and be + mindful of undercores in the name + ''' + func = self._deal_with_super_sub(func) + superscriptidx = func.find("^") + subscriptidx = func.find("_") + if func in accepted_latex_functions: + name = r"\%s" % func + elif len(func) == 1 or func.startswith('\\') or subscriptidx == 1 or superscriptidx == 1: + name = func + else: + if superscriptidx > 0 and subscriptidx > 0: + name = r"\operatorname{%s}%s" %( + func[:min(subscriptidx,superscriptidx)], + func[min(subscriptidx,superscriptidx):]) + elif superscriptidx > 0: + name = r"\operatorname{%s}%s" %( + func[:superscriptidx], + func[superscriptidx:]) + elif subscriptidx > 0: + name = r"\operatorname{%s}%s" %( + func[:subscriptidx], + func[subscriptidx:]) + else: + name = r"\operatorname{%s}" % func + return name + + def _print_Function(self, expr: Function, exp=None) -> str: + r''' + Render functions to LaTeX, handling functions that LaTeX knows about + e.g., sin, cos, ... by using the proper LaTeX command (\sin, \cos, ...). + For single-letter function names, render them as regular LaTeX math + symbols. For multi-letter function names that LaTeX does not know + about, (e.g., Li, sech) use \operatorname{} so that the function name + is rendered in Roman font and LaTeX handles spacing properly. + + expr is the expression involving the function + exp is an exponent + ''' + func = expr.func.__name__ + if hasattr(self, '_print_' + func) and \ + not isinstance(expr, AppliedUndef): + return getattr(self, '_print_' + func)(expr, exp) + else: + args = [str(self._print(arg)) for arg in expr.args] + # How inverse trig functions should be displayed, formats are: + # abbreviated: asin, full: arcsin, power: sin^-1 + inv_trig_style = self._settings['inv_trig_style'] + # If we are dealing with a power-style inverse trig function + inv_trig_power_case = False + # If it is applicable to fold the argument brackets + can_fold_brackets = self._settings['fold_func_brackets'] and \ + len(args) == 1 and \ + not self._needs_function_brackets(expr.args[0]) + + inv_trig_table = [ + "asin", "acos", "atan", + "acsc", "asec", "acot", + "asinh", "acosh", "atanh", + "acsch", "asech", "acoth", + ] + + # If the function is an inverse trig function, handle the style + if func in inv_trig_table: + if inv_trig_style == "abbreviated": + pass + elif inv_trig_style == "full": + func = ("ar" if func[-1] == "h" else "arc") + func[1:] + elif inv_trig_style == "power": + func = func[1:] + inv_trig_power_case = True + + # Can never fold brackets if we're raised to a power + if exp is not None: + can_fold_brackets = False + + if inv_trig_power_case: + if func in accepted_latex_functions: + name = r"\%s^{-1}" % func + else: + name = r"\operatorname{%s}^{-1}" % func + elif exp is not None: + func_tex = self._hprint_Function(func) + func_tex = self.parenthesize_super(func_tex) + name = r'%s^{%s}' % (func_tex, exp) + else: + name = self._hprint_Function(func) + + if can_fold_brackets: + if func in accepted_latex_functions: + # Wrap argument safely to avoid parse-time conflicts + # with the function name itself + name += r" {%s}" + else: + name += r"%s" + else: + name += r"{\left(%s \right)}" + + if inv_trig_power_case and exp is not None: + name += r"^{%s}" % exp + + return name % ",".join(args) + + def _print_UndefinedFunction(self, expr): + return self._hprint_Function(str(expr)) + + def _print_ElementwiseApplyFunction(self, expr): + return r"{%s}_{\circ}\left({%s}\right)" % ( + self._print(expr.function), + self._print(expr.expr), + ) + + @property + def _special_function_classes(self): + from sympy.functions.special.tensor_functions import KroneckerDelta + from sympy.functions.special.gamma_functions import gamma, lowergamma + from sympy.functions.special.beta_functions import beta + from sympy.functions.special.delta_functions import DiracDelta + from sympy.functions.special.error_functions import Chi + return {KroneckerDelta: r'\delta', + gamma: r'\Gamma', + lowergamma: r'\gamma', + beta: r'\operatorname{B}', + DiracDelta: r'\delta', + Chi: r'\operatorname{Chi}'} + + def _print_FunctionClass(self, expr): + for cls in self._special_function_classes: + if issubclass(expr, cls) and expr.__name__ == cls.__name__: + return self._special_function_classes[cls] + return self._hprint_Function(str(expr)) + + def _print_Lambda(self, expr): + symbols, expr = expr.args + + if len(symbols) == 1: + symbols = self._print(symbols[0]) + else: + symbols = self._print(tuple(symbols)) + + tex = r"\left( %s \mapsto %s \right)" % (symbols, self._print(expr)) + + return tex + + def _print_IdentityFunction(self, expr): + return r"\left( x \mapsto x \right)" + + def _hprint_variadic_function(self, expr, exp=None) -> str: + args = sorted(expr.args, key=default_sort_key) + texargs = [r"%s" % self._print(symbol) for symbol in args] + tex = r"\%s\left(%s\right)" % (str(expr.func).lower(), + ", ".join(texargs)) + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + _print_Min = _print_Max = _hprint_variadic_function + + def _print_floor(self, expr, exp=None): + tex = r"\left\lfloor{%s}\right\rfloor" % self._print(expr.args[0]) + + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + def _print_ceiling(self, expr, exp=None): + tex = r"\left\lceil{%s}\right\rceil" % self._print(expr.args[0]) + + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + def _print_log(self, expr, exp=None): + if not self._settings["ln_notation"]: + tex = r"\log{\left(%s \right)}" % self._print(expr.args[0]) + else: + tex = r"\ln{\left(%s \right)}" % self._print(expr.args[0]) + + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + def _print_Abs(self, expr, exp=None): + tex = r"\left|{%s}\right|" % self._print(expr.args[0]) + + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + def _print_re(self, expr, exp=None): + if self._settings['gothic_re_im']: + tex = r"\Re{%s}" % self.parenthesize(expr.args[0], PRECEDENCE['Atom']) + else: + tex = r"\operatorname{{re}}{{{}}}".format(self.parenthesize(expr.args[0], PRECEDENCE['Atom'])) + + return self._do_exponent(tex, exp) + + def _print_im(self, expr, exp=None): + if self._settings['gothic_re_im']: + tex = r"\Im{%s}" % self.parenthesize(expr.args[0], PRECEDENCE['Atom']) + else: + tex = r"\operatorname{{im}}{{{}}}".format(self.parenthesize(expr.args[0], PRECEDENCE['Atom'])) + + return self._do_exponent(tex, exp) + + def _print_Not(self, e): + from sympy.logic.boolalg import (Equivalent, Implies) + if isinstance(e.args[0], Equivalent): + return self._print_Equivalent(e.args[0], r"\not\Leftrightarrow") + if isinstance(e.args[0], Implies): + return self._print_Implies(e.args[0], r"\not\Rightarrow") + if (e.args[0].is_Boolean): + return r"\neg \left(%s\right)" % self._print(e.args[0]) + else: + return r"\neg %s" % self._print(e.args[0]) + + def _print_LogOp(self, args, char): + arg = args[0] + if arg.is_Boolean and not arg.is_Not: + tex = r"\left(%s\right)" % self._print(arg) + else: + tex = r"%s" % self._print(arg) + + for arg in args[1:]: + if arg.is_Boolean and not arg.is_Not: + tex += r" %s \left(%s\right)" % (char, self._print(arg)) + else: + tex += r" %s %s" % (char, self._print(arg)) + + return tex + + def _print_And(self, e): + args = sorted(e.args, key=default_sort_key) + return self._print_LogOp(args, r"\wedge") + + def _print_Or(self, e): + args = sorted(e.args, key=default_sort_key) + return self._print_LogOp(args, r"\vee") + + def _print_Xor(self, e): + args = sorted(e.args, key=default_sort_key) + return self._print_LogOp(args, r"\veebar") + + def _print_Implies(self, e, altchar=None): + return self._print_LogOp(e.args, altchar or r"\Rightarrow") + + def _print_Equivalent(self, e, altchar=None): + args = sorted(e.args, key=default_sort_key) + return self._print_LogOp(args, altchar or r"\Leftrightarrow") + + def _print_conjugate(self, expr, exp=None): + tex = r"\overline{%s}" % self._print(expr.args[0]) + + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + def _print_polar_lift(self, expr, exp=None): + func = r"\operatorname{polar\_lift}" + arg = r"{\left(%s \right)}" % self._print(expr.args[0]) + + if exp is not None: + return r"%s^{%s}%s" % (func, exp, arg) + else: + return r"%s%s" % (func, arg) + + def _print_ExpBase(self, expr, exp=None): + # TODO should exp_polar be printed differently? + # what about exp_polar(0), exp_polar(1)? + tex = r"e^{%s}" % self._print(expr.args[0]) + return self._do_exponent(tex, exp) + + def _print_Exp1(self, expr, exp=None): + return "e" + + def _print_elliptic_k(self, expr, exp=None): + tex = r"\left(%s\right)" % self._print(expr.args[0]) + if exp is not None: + return r"K^{%s}%s" % (exp, tex) + else: + return r"K%s" % tex + + def _print_elliptic_f(self, expr, exp=None): + tex = r"\left(%s\middle| %s\right)" % \ + (self._print(expr.args[0]), self._print(expr.args[1])) + if exp is not None: + return r"F^{%s}%s" % (exp, tex) + else: + return r"F%s" % tex + + def _print_elliptic_e(self, expr, exp=None): + if len(expr.args) == 2: + tex = r"\left(%s\middle| %s\right)" % \ + (self._print(expr.args[0]), self._print(expr.args[1])) + else: + tex = r"\left(%s\right)" % self._print(expr.args[0]) + if exp is not None: + return r"E^{%s}%s" % (exp, tex) + else: + return r"E%s" % tex + + def _print_elliptic_pi(self, expr, exp=None): + if len(expr.args) == 3: + tex = r"\left(%s; %s\middle| %s\right)" % \ + (self._print(expr.args[0]), self._print(expr.args[1]), + self._print(expr.args[2])) + else: + tex = r"\left(%s\middle| %s\right)" % \ + (self._print(expr.args[0]), self._print(expr.args[1])) + if exp is not None: + return r"\Pi^{%s}%s" % (exp, tex) + else: + return r"\Pi%s" % tex + + def _print_beta(self, expr, exp=None): + x = expr.args[0] + # Deal with unevaluated single argument beta + y = expr.args[0] if len(expr.args) == 1 else expr.args[1] + tex = rf"\left({x}, {y}\right)" + + if exp is not None: + return r"\operatorname{B}^{%s}%s" % (exp, tex) + else: + return r"\operatorname{B}%s" % tex + + def _print_betainc(self, expr, exp=None, operator='B'): + largs = [self._print(arg) for arg in expr.args] + tex = r"\left(%s, %s\right)" % (largs[0], largs[1]) + + if exp is not None: + return r"\operatorname{%s}_{(%s, %s)}^{%s}%s" % (operator, largs[2], largs[3], exp, tex) + else: + return r"\operatorname{%s}_{(%s, %s)}%s" % (operator, largs[2], largs[3], tex) + + def _print_betainc_regularized(self, expr, exp=None): + return self._print_betainc(expr, exp, operator='I') + + def _print_uppergamma(self, expr, exp=None): + tex = r"\left(%s, %s\right)" % (self._print(expr.args[0]), + self._print(expr.args[1])) + + if exp is not None: + return r"\Gamma^{%s}%s" % (exp, tex) + else: + return r"\Gamma%s" % tex + + def _print_lowergamma(self, expr, exp=None): + tex = r"\left(%s, %s\right)" % (self._print(expr.args[0]), + self._print(expr.args[1])) + + if exp is not None: + return r"\gamma^{%s}%s" % (exp, tex) + else: + return r"\gamma%s" % tex + + def _hprint_one_arg_func(self, expr, exp=None) -> str: + tex = r"\left(%s\right)" % self._print(expr.args[0]) + + if exp is not None: + return r"%s^{%s}%s" % (self._print(expr.func), exp, tex) + else: + return r"%s%s" % (self._print(expr.func), tex) + + _print_gamma = _hprint_one_arg_func + + def _print_Chi(self, expr, exp=None): + tex = r"\left(%s\right)" % self._print(expr.args[0]) + + if exp is not None: + return r"\operatorname{Chi}^{%s}%s" % (exp, tex) + else: + return r"\operatorname{Chi}%s" % tex + + def _print_expint(self, expr, exp=None): + tex = r"\left(%s\right)" % self._print(expr.args[1]) + nu = self._print(expr.args[0]) + + if exp is not None: + return r"\operatorname{E}_{%s}^{%s}%s" % (nu, exp, tex) + else: + return r"\operatorname{E}_{%s}%s" % (nu, tex) + + def _print_fresnels(self, expr, exp=None): + tex = r"\left(%s\right)" % self._print(expr.args[0]) + + if exp is not None: + return r"S^{%s}%s" % (exp, tex) + else: + return r"S%s" % tex + + def _print_fresnelc(self, expr, exp=None): + tex = r"\left(%s\right)" % self._print(expr.args[0]) + + if exp is not None: + return r"C^{%s}%s" % (exp, tex) + else: + return r"C%s" % tex + + def _print_subfactorial(self, expr, exp=None): + tex = r"!%s" % self.parenthesize(expr.args[0], PRECEDENCE["Func"]) + + if exp is not None: + return r"\left(%s\right)^{%s}" % (tex, exp) + else: + return tex + + def _print_factorial(self, expr, exp=None): + tex = r"%s!" % self.parenthesize(expr.args[0], PRECEDENCE["Func"]) + + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + def _print_factorial2(self, expr, exp=None): + tex = r"%s!!" % self.parenthesize(expr.args[0], PRECEDENCE["Func"]) + + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + def _print_binomial(self, expr, exp=None): + tex = r"{\binom{%s}{%s}}" % (self._print(expr.args[0]), + self._print(expr.args[1])) + + if exp is not None: + return r"%s^{%s}" % (tex, exp) + else: + return tex + + def _print_RisingFactorial(self, expr, exp=None): + n, k = expr.args + base = r"%s" % self.parenthesize(n, PRECEDENCE['Func']) + + tex = r"{%s}^{\left(%s\right)}" % (base, self._print(k)) + + return self._do_exponent(tex, exp) + + def _print_FallingFactorial(self, expr, exp=None): + n, k = expr.args + sub = r"%s" % self.parenthesize(k, PRECEDENCE['Func']) + + tex = r"{\left(%s\right)}_{%s}" % (self._print(n), sub) + + return self._do_exponent(tex, exp) + + def _hprint_BesselBase(self, expr, exp, sym: str) -> str: + tex = r"%s" % (sym) + + need_exp = False + if exp is not None: + if tex.find('^') == -1: + tex = r"%s^{%s}" % (tex, exp) + else: + need_exp = True + + tex = r"%s_{%s}\left(%s\right)" % (tex, self._print(expr.order), + self._print(expr.argument)) + + if need_exp: + tex = self._do_exponent(tex, exp) + return tex + + def _hprint_vec(self, vec) -> str: + if not vec: + return "" + s = "" + for i in vec[:-1]: + s += "%s, " % self._print(i) + s += self._print(vec[-1]) + return s + + def _print_besselj(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'J') + + def _print_besseli(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'I') + + def _print_besselk(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'K') + + def _print_bessely(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'Y') + + def _print_yn(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'y') + + def _print_jn(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'j') + + def _print_hankel1(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'H^{(1)}') + + def _print_hankel2(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'H^{(2)}') + + def _print_hn1(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'h^{(1)}') + + def _print_hn2(self, expr, exp=None): + return self._hprint_BesselBase(expr, exp, 'h^{(2)}') + + def _hprint_airy(self, expr, exp=None, notation="") -> str: + tex = r"\left(%s\right)" % self._print(expr.args[0]) + + if exp is not None: + return r"%s^{%s}%s" % (notation, exp, tex) + else: + return r"%s%s" % (notation, tex) + + def _hprint_airy_prime(self, expr, exp=None, notation="") -> str: + tex = r"\left(%s\right)" % self._print(expr.args[0]) + + if exp is not None: + return r"{%s^\prime}^{%s}%s" % (notation, exp, tex) + else: + return r"%s^\prime%s" % (notation, tex) + + def _print_airyai(self, expr, exp=None): + return self._hprint_airy(expr, exp, 'Ai') + + def _print_airybi(self, expr, exp=None): + return self._hprint_airy(expr, exp, 'Bi') + + def _print_airyaiprime(self, expr, exp=None): + return self._hprint_airy_prime(expr, exp, 'Ai') + + def _print_airybiprime(self, expr, exp=None): + return self._hprint_airy_prime(expr, exp, 'Bi') + + def _print_hyper(self, expr, exp=None): + tex = r"{{}_{%s}F_{%s}\left(\begin{matrix} %s \\ %s \end{matrix}" \ + r"\middle| {%s} \right)}" % \ + (self._print(len(expr.ap)), self._print(len(expr.bq)), + self._hprint_vec(expr.ap), self._hprint_vec(expr.bq), + self._print(expr.argument)) + + if exp is not None: + tex = r"{%s}^{%s}" % (tex, exp) + return tex + + def _print_meijerg(self, expr, exp=None): + tex = r"{G_{%s, %s}^{%s, %s}\left(\begin{matrix} %s & %s \\" \ + r"%s & %s \end{matrix} \middle| {%s} \right)}" % \ + (self._print(len(expr.ap)), self._print(len(expr.bq)), + self._print(len(expr.bm)), self._print(len(expr.an)), + self._hprint_vec(expr.an), self._hprint_vec(expr.aother), + self._hprint_vec(expr.bm), self._hprint_vec(expr.bother), + self._print(expr.argument)) + + if exp is not None: + tex = r"{%s}^{%s}" % (tex, exp) + return tex + + def _print_dirichlet_eta(self, expr, exp=None): + tex = r"\left(%s\right)" % self._print(expr.args[0]) + if exp is not None: + return r"\eta^{%s}%s" % (exp, tex) + return r"\eta%s" % tex + + def _print_zeta(self, expr, exp=None): + if len(expr.args) == 2: + tex = r"\left(%s, %s\right)" % tuple(map(self._print, expr.args)) + else: + tex = r"\left(%s\right)" % self._print(expr.args[0]) + if exp is not None: + return r"\zeta^{%s}%s" % (exp, tex) + return r"\zeta%s" % tex + + def _print_stieltjes(self, expr, exp=None): + if len(expr.args) == 2: + tex = r"_{%s}\left(%s\right)" % tuple(map(self._print, expr.args)) + else: + tex = r"_{%s}" % self._print(expr.args[0]) + if exp is not None: + return r"\gamma%s^{%s}" % (tex, exp) + return r"\gamma%s" % tex + + def _print_lerchphi(self, expr, exp=None): + tex = r"\left(%s, %s, %s\right)" % tuple(map(self._print, expr.args)) + if exp is None: + return r"\Phi%s" % tex + return r"\Phi^{%s}%s" % (exp, tex) + + def _print_polylog(self, expr, exp=None): + s, z = map(self._print, expr.args) + tex = r"\left(%s\right)" % z + if exp is None: + return r"\operatorname{Li}_{%s}%s" % (s, tex) + return r"\operatorname{Li}_{%s}^{%s}%s" % (s, exp, tex) + + def _print_jacobi(self, expr, exp=None): + n, a, b, x = map(self._print, expr.args) + tex = r"P_{%s}^{\left(%s,%s\right)}\left(%s\right)" % (n, a, b, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_gegenbauer(self, expr, exp=None): + n, a, x = map(self._print, expr.args) + tex = r"C_{%s}^{\left(%s\right)}\left(%s\right)" % (n, a, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_chebyshevt(self, expr, exp=None): + n, x = map(self._print, expr.args) + tex = r"T_{%s}\left(%s\right)" % (n, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_chebyshevu(self, expr, exp=None): + n, x = map(self._print, expr.args) + tex = r"U_{%s}\left(%s\right)" % (n, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_legendre(self, expr, exp=None): + n, x = map(self._print, expr.args) + tex = r"P_{%s}\left(%s\right)" % (n, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_assoc_legendre(self, expr, exp=None): + n, a, x = map(self._print, expr.args) + tex = r"P_{%s}^{\left(%s\right)}\left(%s\right)" % (n, a, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_hermite(self, expr, exp=None): + n, x = map(self._print, expr.args) + tex = r"H_{%s}\left(%s\right)" % (n, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_laguerre(self, expr, exp=None): + n, x = map(self._print, expr.args) + tex = r"L_{%s}\left(%s\right)" % (n, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_assoc_laguerre(self, expr, exp=None): + n, a, x = map(self._print, expr.args) + tex = r"L_{%s}^{\left(%s\right)}\left(%s\right)" % (n, a, x) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_Ynm(self, expr, exp=None): + n, m, theta, phi = map(self._print, expr.args) + tex = r"Y_{%s}^{%s}\left(%s,%s\right)" % (n, m, theta, phi) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def _print_Znm(self, expr, exp=None): + n, m, theta, phi = map(self._print, expr.args) + tex = r"Z_{%s}^{%s}\left(%s,%s\right)" % (n, m, theta, phi) + if exp is not None: + tex = r"\left(" + tex + r"\right)^{%s}" % (exp) + return tex + + def __print_mathieu_functions(self, character, args, prime=False, exp=None): + a, q, z = map(self._print, args) + sup = r"^{\prime}" if prime else "" + exp = "" if not exp else "^{%s}" % exp + return r"%s%s\left(%s, %s, %s\right)%s" % (character, sup, a, q, z, exp) + + def _print_mathieuc(self, expr, exp=None): + return self.__print_mathieu_functions("C", expr.args, exp=exp) + + def _print_mathieus(self, expr, exp=None): + return self.__print_mathieu_functions("S", expr.args, exp=exp) + + def _print_mathieucprime(self, expr, exp=None): + return self.__print_mathieu_functions("C", expr.args, prime=True, exp=exp) + + def _print_mathieusprime(self, expr, exp=None): + return self.__print_mathieu_functions("S", expr.args, prime=True, exp=exp) + + def _print_Rational(self, expr): + if expr.q != 1: + sign = "" + p = expr.p + if expr.p < 0: + sign = "- " + p = -p + if self._settings['fold_short_frac']: + return r"%s%d / %d" % (sign, p, expr.q) + return r"%s\frac{%d}{%d}" % (sign, p, expr.q) + else: + return self._print(expr.p) + + def _print_Order(self, expr): + s = self._print(expr.expr) + if expr.point and any(p != S.Zero for p in expr.point) or \ + len(expr.variables) > 1: + s += '; ' + if len(expr.variables) > 1: + s += self._print(expr.variables) + elif expr.variables: + s += self._print(expr.variables[0]) + s += r'\rightarrow ' + if len(expr.point) > 1: + s += self._print(expr.point) + else: + s += self._print(expr.point[0]) + return r"O\left(%s\right)" % s + + def _print_Symbol(self, expr: Symbol, style='plain'): + name: str = self._settings['symbol_names'].get(expr) + if name is not None: + return name + + return self._deal_with_super_sub(expr.name, style=style) + + _print_RandomSymbol = _print_Symbol + + def _split_super_sub(self, name: str) -> tuple[str, list[str], list[str]]: + if name is None or '{' in name: + return (name, [], []) + elif self._settings["disable_split_super_sub"]: + name, supers, subs = (name.replace('_', '\\_').replace('^', '\\^'), [], []) + else: + name, supers, subs = split_super_sub(name) + name = translate(name) + supers = [translate(sup) for sup in supers] + subs = [translate(sub) for sub in subs] + return (name, supers, subs) + + def _deal_with_super_sub(self, string: str, style='plain') -> str: + name, supers, subs = self._split_super_sub(string) + + # apply the style only to the name + if style == 'bold': + name = "\\mathbf{{{}}}".format(name) + + # glue all items together: + if supers: + name += "^{%s}" % " ".join(supers) + if subs: + name += "_{%s}" % " ".join(subs) + + return name + + def _print_Relational(self, expr): + if self._settings['itex']: + gt = r"\gt" + lt = r"\lt" + else: + gt = ">" + lt = "<" + + charmap = { + "==": "=", + ">": gt, + "<": lt, + ">=": r"\geq", + "<=": r"\leq", + "!=": r"\neq", + } + + return "%s %s %s" % (self._print(expr.lhs), + charmap[expr.rel_op], self._print(expr.rhs)) + + def _print_Piecewise(self, expr): + ecpairs = [r"%s & \text{for}\: %s" % (self._print(e), self._print(c)) + for e, c in expr.args[:-1]] + if expr.args[-1].cond == true: + ecpairs.append(r"%s & \text{otherwise}" % + self._print(expr.args[-1].expr)) + else: + ecpairs.append(r"%s & \text{for}\: %s" % + (self._print(expr.args[-1].expr), + self._print(expr.args[-1].cond))) + tex = r"\begin{cases} %s \end{cases}" + return tex % r" \\".join(ecpairs) + + def _print_matrix_contents(self, expr): + lines = [] + + for line in range(expr.rows): # horrible, should be 'rows' + lines.append(" & ".join([self._print(i) for i in expr[line, :]])) + + mat_str = self._settings['mat_str'] + if mat_str is None: + if self._settings['mode'] == 'inline': + mat_str = 'smallmatrix' + else: + if (expr.cols <= 10) is True: + mat_str = 'matrix' + else: + mat_str = 'array' + + out_str = r'\begin{%MATSTR%}%s\end{%MATSTR%}' + out_str = out_str.replace('%MATSTR%', mat_str) + if mat_str == 'array': + out_str = out_str.replace('%s', '{' + 'c'*expr.cols + '}%s') + return out_str % r"\\".join(lines) + + def _print_MatrixBase(self, expr): + out_str = self._print_matrix_contents(expr) + if self._settings['mat_delim']: + left_delim = self._settings['mat_delim'] + right_delim = self._delim_dict[left_delim] + out_str = r'\left' + left_delim + out_str + \ + r'\right' + right_delim + return out_str + + def _print_MatrixElement(self, expr): + matrix_part = self.parenthesize(expr.parent, PRECEDENCE['Atom'], strict=True) + index_part = f"{self._print(expr.i)},{self._print(expr.j)}" + return f"{{{matrix_part}}}_{{{index_part}}}" + + def _print_MatrixSlice(self, expr): + def latexslice(x, dim): + x = list(x) + if x[2] == 1: + del x[2] + if x[0] == 0: + x[0] = None + if x[1] == dim: + x[1] = None + return ':'.join(self._print(xi) if xi is not None else '' for xi in x) + return (self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True) + r'\left[' + + latexslice(expr.rowslice, expr.parent.rows) + ', ' + + latexslice(expr.colslice, expr.parent.cols) + r'\right]') + + def _print_BlockMatrix(self, expr): + return self._print(expr.blocks) + + def _print_Transpose(self, expr): + mat = expr.arg + from sympy.matrices import MatrixSymbol, BlockMatrix + if (not isinstance(mat, MatrixSymbol) and + not isinstance(mat, BlockMatrix) and mat.is_MatrixExpr): + return r"\left(%s\right)^{T}" % self._print(mat) + else: + s = self.parenthesize(mat, precedence_traditional(expr), True) + if '^' in s: + return r"\left(%s\right)^{T}" % s + else: + return "%s^{T}" % s + + def _print_Trace(self, expr): + mat = expr.arg + return r"\operatorname{tr}\left(%s \right)" % self._print(mat) + + def _print_Adjoint(self, expr): + style_to_latex = { + "dagger" : r"\dagger", + "star" : r"\ast", + "hermitian": r"\mathsf{H}" + } + adjoint_style = style_to_latex.get(self._settings["adjoint_style"], r"\dagger") + mat = expr.arg + from sympy.matrices import MatrixSymbol, BlockMatrix + if (not isinstance(mat, MatrixSymbol) and + not isinstance(mat, BlockMatrix) and mat.is_MatrixExpr): + return r"\left(%s\right)^{%s}" % (self._print(mat), adjoint_style) + else: + s = self.parenthesize(mat, precedence_traditional(expr), True) + if '^' in s: + return r"\left(%s\right)^{%s}" % (s, adjoint_style) + else: + return r"%s^{%s}" % (s, adjoint_style) + + def _print_MatMul(self, expr): + from sympy import MatMul + + # Parenthesize nested MatMul but not other types of Mul objects: + parens = lambda x: self._print(x) if isinstance(x, Mul) and not isinstance(x, MatMul) else \ + self.parenthesize(x, precedence_traditional(expr), False) + + args = list(expr.args) + if expr.could_extract_minus_sign(): + if args[0] == -1: + args = args[1:] + else: + args[0] = -args[0] + return '- ' + ' '.join(map(parens, args)) + else: + return ' '.join(map(parens, args)) + + def _print_DotProduct(self, expr): + level = precedence_traditional(expr) + left, right = expr.args + return rf"{self.parenthesize(left, level)} \cdot {self.parenthesize(right, level)}" + + def _print_Determinant(self, expr): + mat = expr.arg + if mat.is_MatrixExpr: + from sympy.matrices.expressions.blockmatrix import BlockMatrix + if isinstance(mat, BlockMatrix): + return r"\left|{%s}\right|" % self._print_matrix_contents(mat.blocks) + return r"\left|{%s}\right|" % self._print(mat) + return r"\left|{%s}\right|" % self._print_matrix_contents(mat) + + + def _print_Mod(self, expr, exp=None): + if exp is not None: + return r'\left(%s \bmod %s\right)^{%s}' % \ + (self.parenthesize(expr.args[0], PRECEDENCE['Mul'], + strict=True), + self.parenthesize(expr.args[1], PRECEDENCE['Mul'], + strict=True), + exp) + return r'%s \bmod %s' % (self.parenthesize(expr.args[0], + PRECEDENCE['Mul'], + strict=True), + self.parenthesize(expr.args[1], + PRECEDENCE['Mul'], + strict=True)) + + def _print_HadamardProduct(self, expr): + args = expr.args + prec = PRECEDENCE['Pow'] + parens = self.parenthesize + + return r' \circ '.join( + (parens(arg, prec, strict=True) for arg in args)) + + def _print_HadamardPower(self, expr): + if precedence_traditional(expr.exp) < PRECEDENCE["Mul"]: + template = r"%s^{\circ \left({%s}\right)}" + else: + template = r"%s^{\circ {%s}}" + return self._helper_print_standard_power(expr, template) + + def _print_KroneckerProduct(self, expr): + args = expr.args + prec = PRECEDENCE['Pow'] + parens = self.parenthesize + + return r' \otimes '.join( + (parens(arg, prec, strict=True) for arg in args)) + + def _print_MatPow(self, expr): + base, exp = expr.base, expr.exp + from sympy.matrices import MatrixSymbol + if not isinstance(base, MatrixSymbol) and base.is_MatrixExpr: + return "\\left(%s\\right)^{%s}" % (self._print(base), + self._print(exp)) + else: + base_str = self._print(base) + if '^' in base_str: + return r"\left(%s\right)^{%s}" % (base_str, self._print(exp)) + else: + return "%s^{%s}" % (base_str, self._print(exp)) + + def _print_MatrixSymbol(self, expr): + return self._print_Symbol(expr, style=self._settings[ + 'mat_symbol_style']) + + def _print_ZeroMatrix(self, Z): + return "0" if self._settings[ + 'mat_symbol_style'] == 'plain' else r"\mathbf{0}" + + def _print_OneMatrix(self, O): + return "1" if self._settings[ + 'mat_symbol_style'] == 'plain' else r"\mathbf{1}" + + def _print_Identity(self, I): + return r"\mathbb{I}" if self._settings[ + 'mat_symbol_style'] == 'plain' else r"\mathbf{I}" + + def _print_PermutationMatrix(self, P): + perm_str = self._print(P.args[0]) + return "P_{%s}" % perm_str + + def _print_NDimArray(self, expr: NDimArray): + + if expr.rank() == 0: + return self._print(expr[()]) + + mat_str = self._settings['mat_str'] + if mat_str is None: + if self._settings['mode'] == 'inline': + mat_str = 'smallmatrix' + else: + if (expr.rank() == 0) or (expr.shape[-1] <= 10): + mat_str = 'matrix' + else: + mat_str = 'array' + block_str = r'\begin{%MATSTR%}%s\end{%MATSTR%}' + block_str = block_str.replace('%MATSTR%', mat_str) + if mat_str == 'array': + block_str = block_str.replace('%s', '{' + 'c'*expr.shape[0] + '}%s') + + if self._settings['mat_delim']: + left_delim: str = self._settings['mat_delim'] + right_delim = self._delim_dict[left_delim] + block_str = r'\left' + left_delim + block_str + \ + r'\right' + right_delim + + if expr.rank() == 0: + return block_str % "" + + level_str: list[list[str]] = [[] for i in range(expr.rank() + 1)] + shape_ranges = [list(range(i)) for i in expr.shape] + for outer_i in itertools.product(*shape_ranges): + level_str[-1].append(self._print(expr[outer_i])) + even = True + for back_outer_i in range(expr.rank()-1, -1, -1): + if len(level_str[back_outer_i+1]) < expr.shape[back_outer_i]: + break + if even: + level_str[back_outer_i].append( + r" & ".join(level_str[back_outer_i+1])) + else: + level_str[back_outer_i].append( + block_str % (r"\\".join(level_str[back_outer_i+1]))) + if len(level_str[back_outer_i+1]) == 1: + level_str[back_outer_i][-1] = r"\left[" + \ + level_str[back_outer_i][-1] + r"\right]" + even = not even + level_str[back_outer_i+1] = [] + + out_str = level_str[0][0] + + if expr.rank() % 2 == 1: + out_str = block_str % out_str + + return out_str + + def _printer_tensor_indices(self, name, indices, index_map: dict): + out_str = self._print(name) + last_valence = None + prev_map = None + for index in indices: + new_valence = index.is_up + if ((index in index_map) or prev_map) and \ + last_valence == new_valence: + out_str += "," + if last_valence != new_valence: + if last_valence is not None: + out_str += "}" + if index.is_up: + out_str += "{}^{" + else: + out_str += "{}_{" + out_str += self._print(index.args[0]) + if index in index_map: + out_str += "=" + out_str += self._print(index_map[index]) + prev_map = True + else: + prev_map = False + last_valence = new_valence + if last_valence is not None: + out_str += "}" + return out_str + + def _print_Tensor(self, expr): + name = expr.args[0].args[0] + indices = expr.get_indices() + return self._printer_tensor_indices(name, indices, {}) + + def _print_TensorElement(self, expr): + name = expr.expr.args[0].args[0] + indices = expr.expr.get_indices() + index_map = expr.index_map + return self._printer_tensor_indices(name, indices, index_map) + + def _print_TensMul(self, expr): + # prints expressions like "A(a)", "3*A(a)", "(1+x)*A(a)" + sign, args = expr._get_args_for_traditional_printer() + return sign + "".join( + [self.parenthesize(arg, precedence(expr)) for arg in args] + ) + + def _print_TensAdd(self, expr): + a = [] + args = expr.args + for x in args: + a.append(self.parenthesize(x, precedence(expr))) + a.sort() + s = ' + '.join(a) + s = s.replace('+ -', '- ') + return s + + def _print_TensorIndex(self, expr): + return "{}%s{%s}" % ( + "^" if expr.is_up else "_", + self._print(expr.args[0]) + ) + + def _print_PartialDerivative(self, expr): + if len(expr.variables) == 1: + return r"\frac{\partial}{\partial {%s}}{%s}" % ( + self._print(expr.variables[0]), + self.parenthesize(expr.expr, PRECEDENCE["Mul"], False) + ) + else: + return r"\frac{\partial^{%s}}{%s}{%s}" % ( + len(expr.variables), + " ".join([r"\partial {%s}" % self._print(i) for i in expr.variables]), + self.parenthesize(expr.expr, PRECEDENCE["Mul"], False) + ) + + def _print_ArraySymbol(self, expr): + return self._print(expr.name) + + def _print_ArrayElement(self, expr): + return "{{%s}_{%s}}" % ( + self.parenthesize(expr.name, PRECEDENCE["Func"], True), + ", ".join([f"{self._print(i)}" for i in expr.indices])) + + def _print_UniversalSet(self, expr): + return r"\mathbb{U}" + + def _print_frac(self, expr, exp=None): + if exp is None: + return r"\operatorname{frac}{\left(%s\right)}" % self._print(expr.args[0]) + else: + return r"\operatorname{frac}{\left(%s\right)}^{%s}" % ( + self._print(expr.args[0]), exp) + + def _print_tuple(self, expr): + if self._settings['decimal_separator'] == 'comma': + sep = ";" + elif self._settings['decimal_separator'] == 'period': + sep = "," + else: + raise ValueError('Unknown Decimal Separator') + + if len(expr) == 1: + # 1-tuple needs a trailing separator + return self._add_parens_lspace(self._print(expr[0]) + sep) + else: + return self._add_parens_lspace( + (sep + r" \ ").join([self._print(i) for i in expr])) + + def _print_TensorProduct(self, expr): + elements = [self._print(a) for a in expr.args] + return r' \otimes '.join(elements) + + def _print_WedgeProduct(self, expr): + elements = [self._print(a) for a in expr.args] + return r' \wedge '.join(elements) + + def _print_Tuple(self, expr): + return self._print_tuple(expr) + + def _print_list(self, expr): + if self._settings['decimal_separator'] == 'comma': + return r"\left[ %s\right]" % \ + r"; \ ".join([self._print(i) for i in expr]) + elif self._settings['decimal_separator'] == 'period': + return r"\left[ %s\right]" % \ + r", \ ".join([self._print(i) for i in expr]) + else: + raise ValueError('Unknown Decimal Separator') + + + def _print_dict(self, d): + keys = sorted(d.keys(), key=default_sort_key) + items = [] + + for key in keys: + val = d[key] + items.append("%s : %s" % (self._print(key), self._print(val))) + + return r"\left\{ %s\right\}" % r", \ ".join(items) + + def _print_Dict(self, expr): + return self._print_dict(expr) + + def _print_DiracDelta(self, expr, exp=None): + if len(expr.args) == 1 or expr.args[1] == 0: + tex = r"\delta\left(%s\right)" % self._print(expr.args[0]) + else: + tex = r"\delta^{\left( %s \right)}\left( %s \right)" % ( + self._print(expr.args[1]), self._print(expr.args[0])) + if exp: + tex = r"\left(%s\right)^{%s}" % (tex, exp) + return tex + + def _print_SingularityFunction(self, expr, exp=None): + shift = self._print(expr.args[0] - expr.args[1]) + power = self._print(expr.args[2]) + tex = r"{\left\langle %s \right\rangle}^{%s}" % (shift, power) + if exp is not None: + tex = r"{\left({\langle %s \rangle}^{%s}\right)}^{%s}" % (shift, power, exp) + return tex + + def _print_Heaviside(self, expr, exp=None): + pargs = ', '.join(self._print(arg) for arg in expr.pargs) + tex = r"\theta\left(%s\right)" % pargs + if exp: + tex = r"\left(%s\right)^{%s}" % (tex, exp) + return tex + + def _print_KroneckerDelta(self, expr, exp=None): + i = self._print(expr.args[0]) + j = self._print(expr.args[1]) + if expr.args[0].is_Atom and expr.args[1].is_Atom: + tex = r'\delta_{%s %s}' % (i, j) + else: + tex = r'\delta_{%s, %s}' % (i, j) + if exp is not None: + tex = r'\left(%s\right)^{%s}' % (tex, exp) + return tex + + def _print_LeviCivita(self, expr, exp=None): + indices = map(self._print, expr.args) + if all(x.is_Atom for x in expr.args): + tex = r'\varepsilon_{%s}' % " ".join(indices) + else: + tex = r'\varepsilon_{%s}' % ", ".join(indices) + if exp: + tex = r'\left(%s\right)^{%s}' % (tex, exp) + return tex + + def _print_RandomDomain(self, d): + if hasattr(d, 'as_boolean'): + return '\\text{Domain: }' + self._print(d.as_boolean()) + elif hasattr(d, 'set'): + return ('\\text{Domain: }' + self._print(d.symbols) + ' \\in ' + + self._print(d.set)) + elif hasattr(d, 'symbols'): + return '\\text{Domain on }' + self._print(d.symbols) + else: + return self._print(None) + + def _print_FiniteSet(self, s): + items = sorted(s.args, key=default_sort_key) + return self._print_set(items) + + def _print_set(self, s): + items = sorted(s, key=default_sort_key) + if self._settings['decimal_separator'] == 'comma': + items = "; ".join(map(self._print, items)) + elif self._settings['decimal_separator'] == 'period': + items = ", ".join(map(self._print, items)) + else: + raise ValueError('Unknown Decimal Separator') + return r"\left\{%s\right\}" % items + + + _print_frozenset = _print_set + + def _print_Range(self, s): + def _print_symbolic_range(): + # Symbolic Range that cannot be resolved + if s.args[0] == 0: + if s.args[2] == 1: + cont = self._print(s.args[1]) + else: + cont = ", ".join(self._print(arg) for arg in s.args) + else: + if s.args[2] == 1: + cont = ", ".join(self._print(arg) for arg in s.args[:2]) + else: + cont = ", ".join(self._print(arg) for arg in s.args) + + return(f"\\text{{Range}}\\left({cont}\\right)") + + dots = object() + + if s.start.is_infinite and s.stop.is_infinite: + if s.step.is_positive: + printset = dots, -1, 0, 1, dots + else: + printset = dots, 1, 0, -1, dots + elif s.start.is_infinite: + printset = dots, s[-1] - s.step, s[-1] + elif s.stop.is_infinite: + it = iter(s) + printset = next(it), next(it), dots + elif s.is_empty is not None: + if (s.size < 4) == True: + printset = tuple(s) + elif s.is_iterable: + it = iter(s) + printset = next(it), next(it), dots, s[-1] + else: + return _print_symbolic_range() + else: + return _print_symbolic_range() + return (r"\left\{" + + r", ".join(self._print(el) if el is not dots else r'\ldots' for el in printset) + + r"\right\}") + + def __print_number_polynomial(self, expr, letter, exp=None): + if len(expr.args) == 2: + if exp is not None: + return r"%s_{%s}^{%s}\left(%s\right)" % (letter, + self._print(expr.args[0]), exp, + self._print(expr.args[1])) + return r"%s_{%s}\left(%s\right)" % (letter, + self._print(expr.args[0]), self._print(expr.args[1])) + + tex = r"%s_{%s}" % (letter, self._print(expr.args[0])) + if exp is not None: + tex = r"%s^{%s}" % (tex, exp) + return tex + + def _print_bernoulli(self, expr, exp=None): + return self.__print_number_polynomial(expr, "B", exp) + + def _print_genocchi(self, expr, exp=None): + return self.__print_number_polynomial(expr, "G", exp) + + def _print_bell(self, expr, exp=None): + if len(expr.args) == 3: + tex1 = r"B_{%s, %s}" % (self._print(expr.args[0]), + self._print(expr.args[1])) + tex2 = r"\left(%s\right)" % r", ".join(self._print(el) for + el in expr.args[2]) + if exp is not None: + tex = r"%s^{%s}%s" % (tex1, exp, tex2) + else: + tex = tex1 + tex2 + return tex + return self.__print_number_polynomial(expr, "B", exp) + + def _print_fibonacci(self, expr, exp=None): + return self.__print_number_polynomial(expr, "F", exp) + + def _print_lucas(self, expr, exp=None): + tex = r"L_{%s}" % self._print(expr.args[0]) + if exp is not None: + tex = r"%s^{%s}" % (tex, exp) + return tex + + def _print_tribonacci(self, expr, exp=None): + return self.__print_number_polynomial(expr, "T", exp) + + def _print_mobius(self, expr, exp=None): + if exp is None: + return r'\mu\left(%s\right)' % self._print(expr.args[0]) + return r'\mu^{%s}\left(%s\right)' % (exp, self._print(expr.args[0])) + + def _print_SeqFormula(self, s): + dots = object() + if len(s.start.free_symbols) > 0 or len(s.stop.free_symbols) > 0: + return r"\left\{%s\right\}_{%s=%s}^{%s}" % ( + self._print(s.formula), + self._print(s.variables[0]), + self._print(s.start), + self._print(s.stop) + ) + if s.start is S.NegativeInfinity: + stop = s.stop + printset = (dots, s.coeff(stop - 3), s.coeff(stop - 2), + s.coeff(stop - 1), s.coeff(stop)) + elif s.stop is S.Infinity or s.length > 4: + printset = s[:4] + printset.append(dots) + else: + printset = tuple(s) + + return (r"\left[" + + r", ".join(self._print(el) if el is not dots else r'\ldots' for el in printset) + + r"\right]") + + _print_SeqPer = _print_SeqFormula + _print_SeqAdd = _print_SeqFormula + _print_SeqMul = _print_SeqFormula + + def _print_Interval(self, i): + if i.start == i.end: + return r"\left\{%s\right\}" % self._print(i.start) + + else: + if i.left_open: + left = '(' + else: + left = '[' + + if i.right_open: + right = ')' + else: + right = ']' + + return r"\left%s%s, %s\right%s" % \ + (left, self._print(i.start), self._print(i.end), right) + + def _print_AccumulationBounds(self, i): + return r"\left\langle %s, %s\right\rangle" % \ + (self._print(i.min), self._print(i.max)) + + def _print_Union(self, u): + prec = precedence_traditional(u) + args_str = [self.parenthesize(i, prec) for i in u.args] + return r" \cup ".join(args_str) + + def _print_Complement(self, u): + prec = precedence_traditional(u) + args_str = [self.parenthesize(i, prec) for i in u.args] + return r" \setminus ".join(args_str) + + def _print_Intersection(self, u): + prec = precedence_traditional(u) + args_str = [self.parenthesize(i, prec) for i in u.args] + return r" \cap ".join(args_str) + + def _print_SymmetricDifference(self, u): + prec = precedence_traditional(u) + args_str = [self.parenthesize(i, prec) for i in u.args] + return r" \triangle ".join(args_str) + + def _print_ProductSet(self, p): + prec = precedence_traditional(p) + if len(p.sets) >= 1 and not has_variety(p.sets): + return self.parenthesize(p.sets[0], prec) + "^{%d}" % len(p.sets) + return r" \times ".join( + self.parenthesize(set, prec) for set in p.sets) + + def _print_EmptySet(self, e): + return r"\emptyset" + + def _print_Naturals(self, n): + return r"\mathbb{N}" + + def _print_Naturals0(self, n): + return r"\mathbb{N}_0" + + def _print_Integers(self, i): + return r"\mathbb{Z}" + + def _print_Rationals(self, i): + return r"\mathbb{Q}" + + def _print_Reals(self, i): + return r"\mathbb{R}" + + def _print_Complexes(self, i): + return r"\mathbb{C}" + + def _print_ImageSet(self, s): + expr = s.lamda.expr + sig = s.lamda.signature + xys = ((self._print(x), self._print(y)) for x, y in zip(sig, s.base_sets)) + xinys = r", ".join(r"%s \in %s" % xy for xy in xys) + return r"\left\{%s\; \middle|\; %s\right\}" % (self._print(expr), xinys) + + def _print_ConditionSet(self, s): + vars_print = ', '.join([self._print(var) for var in Tuple(s.sym)]) + if s.base_set is S.UniversalSet: + return r"\left\{%s\; \middle|\; %s \right\}" % \ + (vars_print, self._print(s.condition)) + + return r"\left\{%s\; \middle|\; %s \in %s \wedge %s \right\}" % ( + vars_print, + vars_print, + self._print(s.base_set), + self._print(s.condition)) + + def _print_PowerSet(self, expr): + arg_print = self._print(expr.args[0]) + return r"\mathcal{{P}}\left({}\right)".format(arg_print) + + def _print_ComplexRegion(self, s): + vars_print = ', '.join([self._print(var) for var in s.variables]) + return r"\left\{%s\; \middle|\; %s \in %s \right\}" % ( + self._print(s.expr), + vars_print, + self._print(s.sets)) + + def _print_Contains(self, e): + return r"%s \in %s" % tuple(self._print(a) for a in e.args) + + def _print_FourierSeries(self, s): + if s.an.formula is S.Zero and s.bn.formula is S.Zero: + return self._print(s.a0) + return self._print_Add(s.truncate()) + r' + \ldots' + + def _print_FormalPowerSeries(self, s): + return self._print_Add(s.infinite) + + def _print_FiniteField(self, expr): + return r"\mathbb{F}_{%s}" % expr.mod + + def _print_IntegerRing(self, expr): + return r"\mathbb{Z}" + + def _print_RationalField(self, expr): + return r"\mathbb{Q}" + + def _print_RealField(self, expr): + return r"\mathbb{R}" + + def _print_ComplexField(self, expr): + return r"\mathbb{C}" + + def _print_PolynomialRing(self, expr): + domain = self._print(expr.domain) + symbols = ", ".join(map(self._print, expr.symbols)) + return r"%s\left[%s\right]" % (domain, symbols) + + def _print_FractionField(self, expr): + domain = self._print(expr.domain) + symbols = ", ".join(map(self._print, expr.symbols)) + return r"%s\left(%s\right)" % (domain, symbols) + + def _print_PolynomialRingBase(self, expr): + domain = self._print(expr.domain) + symbols = ", ".join(map(self._print, expr.symbols)) + inv = "" + if not expr.is_Poly: + inv = r"S_<^{-1}" + return r"%s%s\left[%s\right]" % (inv, domain, symbols) + + def _print_Poly(self, poly): + cls = poly.__class__.__name__ + terms = [] + for monom, coeff in poly.terms(): + s_monom = '' + for i, exp in enumerate(monom): + if exp > 0: + if exp == 1: + s_monom += self._print(poly.gens[i]) + else: + s_monom += self._print(pow(poly.gens[i], exp)) + + if coeff.is_Add: + if s_monom: + s_coeff = r"\left(%s\right)" % self._print(coeff) + else: + s_coeff = self._print(coeff) + else: + if s_monom: + if coeff is S.One: + terms.extend(['+', s_monom]) + continue + + if coeff is S.NegativeOne: + terms.extend(['-', s_monom]) + continue + + s_coeff = self._print(coeff) + + if not s_monom: + s_term = s_coeff + else: + s_term = s_coeff + " " + s_monom + + if s_term.startswith('-'): + terms.extend(['-', s_term[1:]]) + else: + terms.extend(['+', s_term]) + + if terms[0] in ('-', '+'): + modifier = terms.pop(0) + + if modifier == '-': + terms[0] = '-' + terms[0] + + expr = ' '.join(terms) + gens = list(map(self._print, poly.gens)) + domain = "domain=%s" % self._print(poly.get_domain()) + + args = ", ".join([expr] + gens + [domain]) + if cls in accepted_latex_functions: + tex = r"\%s {\left(%s \right)}" % (cls, args) + else: + tex = r"\operatorname{%s}{\left( %s \right)}" % (cls, args) + + return tex + + def _print_ComplexRootOf(self, root): + cls = root.__class__.__name__ + if cls == "ComplexRootOf": + cls = "CRootOf" + expr = self._print(root.expr) + index = root.index + if cls in accepted_latex_functions: + return r"\%s {\left(%s, %d\right)}" % (cls, expr, index) + else: + return r"\operatorname{%s} {\left(%s, %d\right)}" % (cls, expr, + index) + + def _print_RootSum(self, expr): + cls = expr.__class__.__name__ + args = [self._print(expr.expr)] + + if expr.fun is not S.IdentityFunction: + args.append(self._print(expr.fun)) + + if cls in accepted_latex_functions: + return r"\%s {\left(%s\right)}" % (cls, ", ".join(args)) + else: + return r"\operatorname{%s} {\left(%s\right)}" % (cls, + ", ".join(args)) + + def _print_OrdinalOmega(self, expr): + return r"\omega" + + def _print_OmegaPower(self, expr): + exp, mul = expr.args + if mul != 1: + if exp != 1: + return r"{} \omega^{{{}}}".format(mul, exp) + else: + return r"{} \omega".format(mul) + else: + if exp != 1: + return r"\omega^{{{}}}".format(exp) + else: + return r"\omega" + + def _print_Ordinal(self, expr): + return " + ".join([self._print(arg) for arg in expr.args]) + + def _print_PolyElement(self, poly): + mul_symbol = self._settings['mul_symbol_latex'] + return poly.str(self, PRECEDENCE, "{%s}^{%d}", mul_symbol) + + def _print_FracElement(self, frac): + if frac.denom == 1: + return self._print(frac.numer) + else: + numer = self._print(frac.numer) + denom = self._print(frac.denom) + return r"\frac{%s}{%s}" % (numer, denom) + + def _print_euler(self, expr, exp=None): + m, x = (expr.args[0], None) if len(expr.args) == 1 else expr.args + tex = r"E_{%s}" % self._print(m) + if exp is not None: + tex = r"%s^{%s}" % (tex, exp) + if x is not None: + tex = r"%s\left(%s\right)" % (tex, self._print(x)) + return tex + + def _print_catalan(self, expr, exp=None): + tex = r"C_{%s}" % self._print(expr.args[0]) + if exp is not None: + tex = r"%s^{%s}" % (tex, exp) + return tex + + def _print_UnifiedTransform(self, expr, s, inverse=False): + return r"\mathcal{{{}}}{}_{{{}}}\left[{}\right]\left({}\right)".format(s, '^{-1}' if inverse else '', self._print(expr.args[1]), self._print(expr.args[0]), self._print(expr.args[2])) + + def _print_MellinTransform(self, expr): + return self._print_UnifiedTransform(expr, 'M') + + def _print_InverseMellinTransform(self, expr): + return self._print_UnifiedTransform(expr, 'M', True) + + def _print_LaplaceTransform(self, expr): + return self._print_UnifiedTransform(expr, 'L') + + def _print_InverseLaplaceTransform(self, expr): + return self._print_UnifiedTransform(expr, 'L', True) + + def _print_FourierTransform(self, expr): + return self._print_UnifiedTransform(expr, 'F') + + def _print_InverseFourierTransform(self, expr): + return self._print_UnifiedTransform(expr, 'F', True) + + def _print_SineTransform(self, expr): + return self._print_UnifiedTransform(expr, 'SIN') + + def _print_InverseSineTransform(self, expr): + return self._print_UnifiedTransform(expr, 'SIN', True) + + def _print_CosineTransform(self, expr): + return self._print_UnifiedTransform(expr, 'COS') + + def _print_InverseCosineTransform(self, expr): + return self._print_UnifiedTransform(expr, 'COS', True) + + def _print_DMP(self, p): + try: + if p.ring is not None: + # TODO incorporate order + return self._print(p.ring.to_sympy(p)) + except SympifyError: + pass + return self._print(repr(p)) + + def _print_DMF(self, p): + return self._print_DMP(p) + + def _print_Object(self, object): + return self._print(Symbol(object.name)) + + def _print_LambertW(self, expr, exp=None): + arg0 = self._print(expr.args[0]) + exp = r"^{%s}" % (exp,) if exp is not None else "" + if len(expr.args) == 1: + result = r"W%s\left(%s\right)" % (exp, arg0) + else: + arg1 = self._print(expr.args[1]) + result = "W{0}_{{{1}}}\\left({2}\\right)".format(exp, arg1, arg0) + return result + + def _print_Expectation(self, expr): + return r"\operatorname{{E}}\left[{}\right]".format(self._print(expr.args[0])) + + def _print_Variance(self, expr): + return r"\operatorname{{Var}}\left({}\right)".format(self._print(expr.args[0])) + + def _print_Covariance(self, expr): + return r"\operatorname{{Cov}}\left({}\right)".format(", ".join(self._print(arg) for arg in expr.args)) + + def _print_Probability(self, expr): + return r"\operatorname{{P}}\left({}\right)".format(self._print(expr.args[0])) + + def _print_Morphism(self, morphism): + domain = self._print(morphism.domain) + codomain = self._print(morphism.codomain) + return "%s\\rightarrow %s" % (domain, codomain) + + def _print_TransferFunction(self, expr): + num, den = self._print(expr.num), self._print(expr.den) + return r"\frac{%s}{%s}" % (num, den) + + def _print_Series(self, expr): + args = list(expr.args) + parens = lambda x: self.parenthesize(x, precedence_traditional(expr), + False) + return ' '.join(map(parens, args)) + + def _print_MIMOSeries(self, expr): + from sympy.physics.control.lti import MIMOParallel + args = list(expr.args)[::-1] + parens = lambda x: self.parenthesize(x, precedence_traditional(expr), + False) if isinstance(x, MIMOParallel) else self._print(x) + return r"\cdot".join(map(parens, args)) + + def _print_Parallel(self, expr): + return ' + '.join(map(self._print, expr.args)) + + def _print_MIMOParallel(self, expr): + return ' + '.join(map(self._print, expr.args)) + + def _print_Feedback(self, expr): + from sympy.physics.control import TransferFunction, Series + + num, tf = expr.sys1, TransferFunction(1, 1, expr.var) + num_arg_list = list(num.args) if isinstance(num, Series) else [num] + den_arg_list = list(expr.sys2.args) if \ + isinstance(expr.sys2, Series) else [expr.sys2] + den_term_1 = tf + + if isinstance(num, Series) and isinstance(expr.sys2, Series): + den_term_2 = Series(*num_arg_list, *den_arg_list) + elif isinstance(num, Series) and isinstance(expr.sys2, TransferFunction): + if expr.sys2 == tf: + den_term_2 = Series(*num_arg_list) + else: + den_term_2 = tf, Series(*num_arg_list, expr.sys2) + elif isinstance(num, TransferFunction) and isinstance(expr.sys2, Series): + if num == tf: + den_term_2 = Series(*den_arg_list) + else: + den_term_2 = Series(num, *den_arg_list) + else: + if num == tf: + den_term_2 = Series(*den_arg_list) + elif expr.sys2 == tf: + den_term_2 = Series(*num_arg_list) + else: + den_term_2 = Series(*num_arg_list, *den_arg_list) + + numer = self._print(num) + denom_1 = self._print(den_term_1) + denom_2 = self._print(den_term_2) + _sign = "+" if expr.sign == -1 else "-" + + return r"\frac{%s}{%s %s %s}" % (numer, denom_1, _sign, denom_2) + + def _print_MIMOFeedback(self, expr): + from sympy.physics.control import MIMOSeries + inv_mat = self._print(MIMOSeries(expr.sys2, expr.sys1)) + sys1 = self._print(expr.sys1) + _sign = "+" if expr.sign == -1 else "-" + return r"\left(I_{\tau} %s %s\right)^{-1} \cdot %s" % (_sign, inv_mat, sys1) + + def _print_TransferFunctionMatrix(self, expr): + mat = self._print(expr._expr_mat) + return r"%s_\tau" % mat + + def _print_DFT(self, expr): + return r"\text{{{}}}_{{{}}}".format(expr.__class__.__name__, expr.n) + _print_IDFT = _print_DFT + + def _print_NamedMorphism(self, morphism): + pretty_name = self._print(Symbol(morphism.name)) + pretty_morphism = self._print_Morphism(morphism) + return "%s:%s" % (pretty_name, pretty_morphism) + + def _print_IdentityMorphism(self, morphism): + from sympy.categories import NamedMorphism + return self._print_NamedMorphism(NamedMorphism( + morphism.domain, morphism.codomain, "id")) + + def _print_CompositeMorphism(self, morphism): + # All components of the morphism have names and it is thus + # possible to build the name of the composite. + component_names_list = [self._print(Symbol(component.name)) for + component in morphism.components] + component_names_list.reverse() + component_names = "\\circ ".join(component_names_list) + ":" + + pretty_morphism = self._print_Morphism(morphism) + return component_names + pretty_morphism + + def _print_Category(self, morphism): + return r"\mathbf{{{}}}".format(self._print(Symbol(morphism.name))) + + def _print_Diagram(self, diagram): + if not diagram.premises: + # This is an empty diagram. + return self._print(S.EmptySet) + + latex_result = self._print(diagram.premises) + if diagram.conclusions: + latex_result += "\\Longrightarrow %s" % \ + self._print(diagram.conclusions) + + return latex_result + + def _print_DiagramGrid(self, grid): + latex_result = "\\begin{array}{%s}\n" % ("c" * grid.width) + + for i in range(grid.height): + for j in range(grid.width): + if grid[i, j]: + latex_result += latex(grid[i, j]) + latex_result += " " + if j != grid.width - 1: + latex_result += "& " + + if i != grid.height - 1: + latex_result += "\\\\" + latex_result += "\n" + + latex_result += "\\end{array}\n" + return latex_result + + def _print_FreeModule(self, M): + return '{{{}}}^{{{}}}'.format(self._print(M.ring), self._print(M.rank)) + + def _print_FreeModuleElement(self, m): + # Print as row vector for convenience, for now. + return r"\left[ {} \right]".format(",".join( + '{' + self._print(x) + '}' for x in m)) + + def _print_SubModule(self, m): + gens = [[self._print(m.ring.to_sympy(x)) for x in g] for g in m.gens] + curly = lambda o: r"{" + o + r"}" + square = lambda o: r"\left[ " + o + r" \right]" + gens_latex = ",".join(curly(square(",".join(curly(x) for x in g))) for g in gens) + return r"\left\langle {} \right\rangle".format(gens_latex) + + def _print_SubQuotientModule(self, m): + gens_latex = ",".join(["{" + self._print(g) + "}" for g in m.gens]) + return r"\left\langle {} \right\rangle".format(gens_latex) + + def _print_ModuleImplementedIdeal(self, m): + gens = [m.ring.to_sympy(x) for [x] in m._module.gens] + gens_latex = ",".join('{' + self._print(x) + '}' for x in gens) + return r"\left\langle {} \right\rangle".format(gens_latex) + + def _print_Quaternion(self, expr): + # TODO: This expression is potentially confusing, + # shall we print it as `Quaternion( ... )`? + s = [self.parenthesize(i, PRECEDENCE["Mul"], strict=True) + for i in expr.args] + a = [s[0]] + [i+" "+j for i, j in zip(s[1:], "ijk")] + return " + ".join(a) + + def _print_QuotientRing(self, R): + # TODO nicer fractions for few generators... + return r"\frac{{{}}}{{{}}}".format(self._print(R.ring), + self._print(R.base_ideal)) + + def _print_QuotientRingElement(self, x): + x_latex = self._print(x.ring.to_sympy(x)) + return r"{{{}}} + {{{}}}".format(x_latex, + self._print(x.ring.base_ideal)) + + def _print_QuotientModuleElement(self, m): + data = [m.module.ring.to_sympy(x) for x in m.data] + data_latex = r"\left[ {} \right]".format(",".join( + '{' + self._print(x) + '}' for x in data)) + return r"{{{}}} + {{{}}}".format(data_latex, + self._print(m.module.killed_module)) + + def _print_QuotientModule(self, M): + # TODO nicer fractions for few generators... + return r"\frac{{{}}}{{{}}}".format(self._print(M.base), + self._print(M.killed_module)) + + def _print_MatrixHomomorphism(self, h): + return r"{{{}}} : {{{}}} \to {{{}}}".format(self._print(h._sympy_matrix()), + self._print(h.domain), self._print(h.codomain)) + + def _print_Manifold(self, manifold): + name, supers, subs = self._split_super_sub(manifold.name.name) + + name = r'\text{%s}' % name + if supers: + name += "^{%s}" % " ".join(supers) + if subs: + name += "_{%s}" % " ".join(subs) + + return name + + def _print_Patch(self, patch): + return r'\text{%s}_{%s}' % (self._print(patch.name), self._print(patch.manifold)) + + def _print_CoordSystem(self, coordsys): + return r'\text{%s}^{\text{%s}}_{%s}' % ( + self._print(coordsys.name), self._print(coordsys.patch.name), self._print(coordsys.manifold) + ) + + def _print_CovarDerivativeOp(self, cvd): + return r'\mathbb{\nabla}_{%s}' % self._print(cvd._wrt) + + def _print_BaseScalarField(self, field): + string = field._coord_sys.symbols[field._index].name + return r'\mathbf{{{}}}'.format(self._print(Symbol(string))) + + def _print_BaseVectorField(self, field): + string = field._coord_sys.symbols[field._index].name + return r'\partial_{{{}}}'.format(self._print(Symbol(string))) + + def _print_Differential(self, diff): + field = diff._form_field + if hasattr(field, '_coord_sys'): + string = field._coord_sys.symbols[field._index].name + return r'\operatorname{{d}}{}'.format(self._print(Symbol(string))) + else: + string = self._print(field) + return r'\operatorname{{d}}\left({}\right)'.format(string) + + def _print_Tr(self, p): + # TODO: Handle indices + contents = self._print(p.args[0]) + return r'\operatorname{{tr}}\left({}\right)'.format(contents) + + def _print_totient(self, expr, exp=None): + if exp is not None: + return r'\left(\phi\left(%s\right)\right)^{%s}' % \ + (self._print(expr.args[0]), exp) + return r'\phi\left(%s\right)' % self._print(expr.args[0]) + + def _print_reduced_totient(self, expr, exp=None): + if exp is not None: + return r'\left(\lambda\left(%s\right)\right)^{%s}' % \ + (self._print(expr.args[0]), exp) + return r'\lambda\left(%s\right)' % self._print(expr.args[0]) + + def _print_divisor_sigma(self, expr, exp=None): + if len(expr.args) == 2: + tex = r"_%s\left(%s\right)" % tuple(map(self._print, + (expr.args[1], expr.args[0]))) + else: + tex = r"\left(%s\right)" % self._print(expr.args[0]) + if exp is not None: + return r"\sigma^{%s}%s" % (exp, tex) + return r"\sigma%s" % tex + + def _print_udivisor_sigma(self, expr, exp=None): + if len(expr.args) == 2: + tex = r"_%s\left(%s\right)" % tuple(map(self._print, + (expr.args[1], expr.args[0]))) + else: + tex = r"\left(%s\right)" % self._print(expr.args[0]) + if exp is not None: + return r"\sigma^*^{%s}%s" % (exp, tex) + return r"\sigma^*%s" % tex + + def _print_primenu(self, expr, exp=None): + if exp is not None: + return r'\left(\nu\left(%s\right)\right)^{%s}' % \ + (self._print(expr.args[0]), exp) + return r'\nu\left(%s\right)' % self._print(expr.args[0]) + + def _print_primeomega(self, expr, exp=None): + if exp is not None: + return r'\left(\Omega\left(%s\right)\right)^{%s}' % \ + (self._print(expr.args[0]), exp) + return r'\Omega\left(%s\right)' % self._print(expr.args[0]) + + def _print_Str(self, s): + return str(s.name) + + def _print_float(self, expr): + return self._print(Float(expr)) + + def _print_int(self, expr): + return str(expr) + + def _print_mpz(self, expr): + return str(expr) + + def _print_mpq(self, expr): + return str(expr) + + def _print_fmpz(self, expr): + return str(expr) + + def _print_fmpq(self, expr): + return str(expr) + + def _print_Predicate(self, expr): + return r"\operatorname{{Q}}_{{\text{{{}}}}}".format(latex_escape(str(expr.name))) + + def _print_AppliedPredicate(self, expr): + pred = expr.function + args = expr.arguments + pred_latex = self._print(pred) + args_latex = ', '.join([self._print(a) for a in args]) + return '%s(%s)' % (pred_latex, args_latex) + + def emptyPrinter(self, expr): + # default to just printing as monospace, like would normally be shown + s = super().emptyPrinter(expr) + + return r"\mathtt{\text{%s}}" % latex_escape(s) + + +def translate(s: str) -> str: + r''' + Check for a modifier ending the string. If present, convert the + modifier to latex and translate the rest recursively. + + Given a description of a Greek letter or other special character, + return the appropriate latex. + + Let everything else pass as given. + + >>> from sympy.printing.latex import translate + >>> translate('alphahatdotprime') + "{\\dot{\\hat{\\alpha}}}'" + ''' + # Process the rest + tex = tex_greek_dictionary.get(s) + if tex: + return tex + elif s.lower() in greek_letters_set: + return "\\" + s.lower() + elif s in other_symbols: + return "\\" + s + else: + # Process modifiers, if any, and recurse + for key in sorted(modifier_dict.keys(), key=len, reverse=True): + if s.lower().endswith(key) and len(s) > len(key): + return modifier_dict[key](translate(s[:-len(key)])) + return s + + + +@print_function(LatexPrinter) +def latex(expr, **settings): + r"""Convert the given expression to LaTeX string representation. + + Parameters + ========== + full_prec: boolean, optional + If set to True, a floating point number is printed with full precision. + fold_frac_powers : boolean, optional + Emit ``^{p/q}`` instead of ``^{\frac{p}{q}}`` for fractional powers. + fold_func_brackets : boolean, optional + Fold function brackets where applicable. + fold_short_frac : boolean, optional + Emit ``p / q`` instead of ``\frac{p}{q}`` when the denominator is + simple enough (at most two terms and no powers). The default value is + ``True`` for inline mode, ``False`` otherwise. + inv_trig_style : string, optional + How inverse trig functions should be displayed. Can be one of + ``'abbreviated'``, ``'full'``, or ``'power'``. Defaults to + ``'abbreviated'``. + itex : boolean, optional + Specifies if itex-specific syntax is used, including emitting + ``$$...$$``. + ln_notation : boolean, optional + If set to ``True``, ``\ln`` is used instead of default ``\log``. + long_frac_ratio : float or None, optional + The allowed ratio of the width of the numerator to the width of the + denominator before the printer breaks off long fractions. If ``None`` + (the default value), long fractions are not broken up. + mat_delim : string, optional + The delimiter to wrap around matrices. Can be one of ``'['``, ``'('``, + or the empty string ``''``. Defaults to ``'['``. + mat_str : string, optional + Which matrix environment string to emit. ``'smallmatrix'``, + ``'matrix'``, ``'array'``, etc. Defaults to ``'smallmatrix'`` for + inline mode, ``'matrix'`` for matrices of no more than 10 columns, and + ``'array'`` otherwise. + mode: string, optional + Specifies how the generated code will be delimited. ``mode`` can be one + of ``'plain'``, ``'inline'``, ``'equation'`` or ``'equation*'``. If + ``mode`` is set to ``'plain'``, then the resulting code will not be + delimited at all (this is the default). If ``mode`` is set to + ``'inline'`` then inline LaTeX ``$...$`` will be used. If ``mode`` is + set to ``'equation'`` or ``'equation*'``, the resulting code will be + enclosed in the ``equation`` or ``equation*`` environment (remember to + import ``amsmath`` for ``equation*``), unless the ``itex`` option is + set. In the latter case, the ``$$...$$`` syntax is used. + mul_symbol : string or None, optional + The symbol to use for multiplication. Can be one of ``None``, + ``'ldot'``, ``'dot'``, or ``'times'``. + order: string, optional + Any of the supported monomial orderings (currently ``'lex'``, + ``'grlex'``, or ``'grevlex'``), ``'old'``, and ``'none'``. This + parameter does nothing for `~.Mul` objects. Setting order to ``'old'`` + uses the compatibility ordering for ``~.Add`` defined in Printer. For + very large expressions, set the ``order`` keyword to ``'none'`` if + speed is a concern. + symbol_names : dictionary of strings mapped to symbols, optional + Dictionary of symbols and the custom strings they should be emitted as. + root_notation : boolean, optional + If set to ``False``, exponents of the form 1/n are printed in fractonal + form. Default is ``True``, to print exponent in root form. + mat_symbol_style : string, optional + Can be either ``'plain'`` (default) or ``'bold'``. If set to + ``'bold'``, a `~.MatrixSymbol` A will be printed as ``\mathbf{A}``, + otherwise as ``A``. + imaginary_unit : string, optional + String to use for the imaginary unit. Defined options are ``'i'`` + (default) and ``'j'``. Adding ``r`` or ``t`` in front gives ``\mathrm`` + or ``\text``, so ``'ri'`` leads to ``\mathrm{i}`` which gives + `\mathrm{i}`. + gothic_re_im : boolean, optional + If set to ``True``, `\Re` and `\Im` is used for ``re`` and ``im``, respectively. + The default is ``False`` leading to `\operatorname{re}` and `\operatorname{im}`. + decimal_separator : string, optional + Specifies what separator to use to separate the whole and fractional parts of a + floating point number as in `2.5` for the default, ``period`` or `2{,}5` + when ``comma`` is specified. Lists, sets, and tuple are printed with semicolon + separating the elements when ``comma`` is chosen. For example, [1; 2; 3] when + ``comma`` is chosen and [1,2,3] for when ``period`` is chosen. + parenthesize_super : boolean, optional + If set to ``False``, superscripted expressions will not be parenthesized when + powered. Default is ``True``, which parenthesizes the expression when powered. + min: Integer or None, optional + Sets the lower bound for the exponent to print floating point numbers in + fixed-point format. + max: Integer or None, optional + Sets the upper bound for the exponent to print floating point numbers in + fixed-point format. + diff_operator: string, optional + String to use for differential operator. Default is ``'d'``, to print in italic + form. ``'rd'``, ``'td'`` are shortcuts for ``\mathrm{d}`` and ``\text{d}``. + adjoint_style: string, optional + String to use for the adjoint symbol. Defined options are ``'dagger'`` + (default),``'star'``, and ``'hermitian'``. + + Notes + ===== + + Not using a print statement for printing, results in double backslashes for + latex commands since that's the way Python escapes backslashes in strings. + + >>> from sympy import latex, Rational + >>> from sympy.abc import tau + >>> latex((2*tau)**Rational(7,2)) + '8 \\sqrt{2} \\tau^{\\frac{7}{2}}' + >>> print(latex((2*tau)**Rational(7,2))) + 8 \sqrt{2} \tau^{\frac{7}{2}} + + Examples + ======== + + >>> from sympy import latex, pi, sin, asin, Integral, Matrix, Rational, log + >>> from sympy.abc import x, y, mu, r, tau + + Basic usage: + + >>> print(latex((2*tau)**Rational(7,2))) + 8 \sqrt{2} \tau^{\frac{7}{2}} + + ``mode`` and ``itex`` options: + + >>> print(latex((2*mu)**Rational(7,2), mode='plain')) + 8 \sqrt{2} \mu^{\frac{7}{2}} + >>> print(latex((2*tau)**Rational(7,2), mode='inline')) + $8 \sqrt{2} \tau^{7 / 2}$ + >>> print(latex((2*mu)**Rational(7,2), mode='equation*')) + \begin{equation*}8 \sqrt{2} \mu^{\frac{7}{2}}\end{equation*} + >>> print(latex((2*mu)**Rational(7,2), mode='equation')) + \begin{equation}8 \sqrt{2} \mu^{\frac{7}{2}}\end{equation} + >>> print(latex((2*mu)**Rational(7,2), mode='equation', itex=True)) + $$8 \sqrt{2} \mu^{\frac{7}{2}}$$ + >>> print(latex((2*mu)**Rational(7,2), mode='plain')) + 8 \sqrt{2} \mu^{\frac{7}{2}} + >>> print(latex((2*tau)**Rational(7,2), mode='inline')) + $8 \sqrt{2} \tau^{7 / 2}$ + >>> print(latex((2*mu)**Rational(7,2), mode='equation*')) + \begin{equation*}8 \sqrt{2} \mu^{\frac{7}{2}}\end{equation*} + >>> print(latex((2*mu)**Rational(7,2), mode='equation')) + \begin{equation}8 \sqrt{2} \mu^{\frac{7}{2}}\end{equation} + >>> print(latex((2*mu)**Rational(7,2), mode='equation', itex=True)) + $$8 \sqrt{2} \mu^{\frac{7}{2}}$$ + + Fraction options: + + >>> print(latex((2*tau)**Rational(7,2), fold_frac_powers=True)) + 8 \sqrt{2} \tau^{7/2} + >>> print(latex((2*tau)**sin(Rational(7,2)))) + \left(2 \tau\right)^{\sin{\left(\frac{7}{2} \right)}} + >>> print(latex((2*tau)**sin(Rational(7,2)), fold_func_brackets=True)) + \left(2 \tau\right)^{\sin {\frac{7}{2}}} + >>> print(latex(3*x**2/y)) + \frac{3 x^{2}}{y} + >>> print(latex(3*x**2/y, fold_short_frac=True)) + 3 x^{2} / y + >>> print(latex(Integral(r, r)/2/pi, long_frac_ratio=2)) + \frac{\int r\, dr}{2 \pi} + >>> print(latex(Integral(r, r)/2/pi, long_frac_ratio=0)) + \frac{1}{2 \pi} \int r\, dr + + Multiplication options: + + >>> print(latex((2*tau)**sin(Rational(7,2)), mul_symbol="times")) + \left(2 \times \tau\right)^{\sin{\left(\frac{7}{2} \right)}} + + Trig options: + + >>> print(latex(asin(Rational(7,2)))) + \operatorname{asin}{\left(\frac{7}{2} \right)} + >>> print(latex(asin(Rational(7,2)), inv_trig_style="full")) + \arcsin{\left(\frac{7}{2} \right)} + >>> print(latex(asin(Rational(7,2)), inv_trig_style="power")) + \sin^{-1}{\left(\frac{7}{2} \right)} + + Matrix options: + + >>> print(latex(Matrix(2, 1, [x, y]))) + \left[\begin{matrix}x\\y\end{matrix}\right] + >>> print(latex(Matrix(2, 1, [x, y]), mat_str = "array")) + \left[\begin{array}{c}x\\y\end{array}\right] + >>> print(latex(Matrix(2, 1, [x, y]), mat_delim="(")) + \left(\begin{matrix}x\\y\end{matrix}\right) + + Custom printing of symbols: + + >>> print(latex(x**2, symbol_names={x: 'x_i'})) + x_i^{2} + + Logarithms: + + >>> print(latex(log(10))) + \log{\left(10 \right)} + >>> print(latex(log(10), ln_notation=True)) + \ln{\left(10 \right)} + + ``latex()`` also supports the builtin container types :class:`list`, + :class:`tuple`, and :class:`dict`: + + >>> print(latex([2/x, y], mode='inline')) + $\left[ 2 / x, \ y\right]$ + + Unsupported types are rendered as monospaced plaintext: + + >>> print(latex(int)) + \mathtt{\text{}} + >>> print(latex("plain % text")) + \mathtt{\text{plain \% text}} + + See :ref:`printer_method_example` for an example of how to override + this behavior for your own types by implementing ``_latex``. + + .. versionchanged:: 1.7.0 + Unsupported types no longer have their ``str`` representation treated as valid latex. + + """ + return LatexPrinter(settings).doprint(expr) + + +def print_latex(expr, **settings): + """Prints LaTeX representation of the given expression. Takes the same + settings as ``latex()``.""" + + print(latex(expr, **settings)) + + +def multiline_latex(lhs, rhs, terms_per_line=1, environment="align*", use_dots=False, **settings): + r""" + This function generates a LaTeX equation with a multiline right-hand side + in an ``align*``, ``eqnarray`` or ``IEEEeqnarray`` environment. + + Parameters + ========== + + lhs : Expr + Left-hand side of equation + + rhs : Expr + Right-hand side of equation + + terms_per_line : integer, optional + Number of terms per line to print. Default is 1. + + environment : "string", optional + Which LaTeX wnvironment to use for the output. Options are "align*" + (default), "eqnarray", and "IEEEeqnarray". + + use_dots : boolean, optional + If ``True``, ``\\dots`` is added to the end of each line. Default is ``False``. + + Examples + ======== + + >>> from sympy import multiline_latex, symbols, sin, cos, exp, log, I + >>> x, y, alpha = symbols('x y alpha') + >>> expr = sin(alpha*y) + exp(I*alpha) - cos(log(y)) + >>> print(multiline_latex(x, expr)) + \begin{align*} + x = & e^{i \alpha} \\ + & + \sin{\left(\alpha y \right)} \\ + & - \cos{\left(\log{\left(y \right)} \right)} + \end{align*} + + Using at most two terms per line: + >>> print(multiline_latex(x, expr, 2)) + \begin{align*} + x = & e^{i \alpha} + \sin{\left(\alpha y \right)} \\ + & - \cos{\left(\log{\left(y \right)} \right)} + \end{align*} + + Using ``eqnarray`` and dots: + >>> print(multiline_latex(x, expr, terms_per_line=2, environment="eqnarray", use_dots=True)) + \begin{eqnarray} + x & = & e^{i \alpha} + \sin{\left(\alpha y \right)} \dots\nonumber\\ + & & - \cos{\left(\log{\left(y \right)} \right)} + \end{eqnarray} + + Using ``IEEEeqnarray``: + >>> print(multiline_latex(x, expr, environment="IEEEeqnarray")) + \begin{IEEEeqnarray}{rCl} + x & = & e^{i \alpha} \nonumber\\ + & & + \sin{\left(\alpha y \right)} \nonumber\\ + & & - \cos{\left(\log{\left(y \right)} \right)} + \end{IEEEeqnarray} + + Notes + ===== + + All optional parameters from ``latex`` can also be used. + + """ + + # Based on code from https://github.com/sympy/sympy/issues/3001 + l = LatexPrinter(**settings) + if environment == "eqnarray": + result = r'\begin{eqnarray}' + '\n' + first_term = '& = &' + nonumber = r'\nonumber' + end_term = '\n\\end{eqnarray}' + doubleet = True + elif environment == "IEEEeqnarray": + result = r'\begin{IEEEeqnarray}{rCl}' + '\n' + first_term = '& = &' + nonumber = r'\nonumber' + end_term = '\n\\end{IEEEeqnarray}' + doubleet = True + elif environment == "align*": + result = r'\begin{align*}' + '\n' + first_term = '= &' + nonumber = '' + end_term = '\n\\end{align*}' + doubleet = False + else: + raise ValueError("Unknown environment: {}".format(environment)) + dots = '' + if use_dots: + dots=r'\dots' + terms = rhs.as_ordered_terms() + n_terms = len(terms) + term_count = 1 + for i in range(n_terms): + term = terms[i] + term_start = '' + term_end = '' + sign = '+' + if term_count > terms_per_line: + if doubleet: + term_start = '& & ' + else: + term_start = '& ' + term_count = 1 + if term_count == terms_per_line: + # End of line + if i < n_terms-1: + # There are terms remaining + term_end = dots + nonumber + r'\\' + '\n' + else: + term_end = '' + + if term.as_ordered_factors()[0] == -1: + term = -1*term + sign = r'-' + if i == 0: # beginning + if sign == '+': + sign = '' + result += r'{:s} {:s}{:s} {:s} {:s}'.format(l.doprint(lhs), + first_term, sign, l.doprint(term), term_end) + else: + result += r'{:s}{:s} {:s} {:s}'.format(term_start, sign, + l.doprint(term), term_end) + term_count += 1 + result += end_term + return result diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/llvmjitcode.py b/.venv/lib/python3.13/site-packages/sympy/printing/llvmjitcode.py new file mode 100644 index 0000000000000000000000000000000000000000..0e657f3b854be62fb4b6b4be96f82579b718f6ca --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/llvmjitcode.py @@ -0,0 +1,490 @@ +''' +Use llvmlite to create executable functions from SymPy expressions + +This module requires llvmlite (https://github.com/numba/llvmlite). +''' + +import ctypes + +from sympy.external import import_module +from sympy.printing.printer import Printer +from sympy.core.singleton import S +from sympy.tensor.indexed import IndexedBase +from sympy.utilities.decorator import doctest_depends_on + +llvmlite = import_module('llvmlite') +if llvmlite: + ll = import_module('llvmlite.ir').ir + llvm = import_module('llvmlite.binding').binding + llvm.initialize() + llvm.initialize_native_target() + llvm.initialize_native_asmprinter() + + +__doctest_requires__ = {('llvm_callable'): ['llvmlite']} + + +class LLVMJitPrinter(Printer): + '''Convert expressions to LLVM IR''' + def __init__(self, module, builder, fn, *args, **kwargs): + self.func_arg_map = kwargs.pop("func_arg_map", {}) + if not llvmlite: + raise ImportError("llvmlite is required for LLVMJITPrinter") + super().__init__(*args, **kwargs) + self.fp_type = ll.DoubleType() + self.module = module + self.builder = builder + self.fn = fn + self.ext_fn = {} # keep track of wrappers to external functions + self.tmp_var = {} + + def _add_tmp_var(self, name, value): + self.tmp_var[name] = value + + def _print_Number(self, n): + return ll.Constant(self.fp_type, float(n)) + + def _print_Integer(self, expr): + return ll.Constant(self.fp_type, float(expr.p)) + + def _print_Symbol(self, s): + val = self.tmp_var.get(s) + if not val: + # look up parameter with name s + val = self.func_arg_map.get(s) + if not val: + raise LookupError("Symbol not found: %s" % s) + return val + + def _print_Pow(self, expr): + base0 = self._print(expr.base) + if expr.exp == S.NegativeOne: + return self.builder.fdiv(ll.Constant(self.fp_type, 1.0), base0) + if expr.exp == S.Half: + fn = self.ext_fn.get("sqrt") + if not fn: + fn_type = ll.FunctionType(self.fp_type, [self.fp_type]) + fn = ll.Function(self.module, fn_type, "sqrt") + self.ext_fn["sqrt"] = fn + return self.builder.call(fn, [base0], "sqrt") + if expr.exp == 2: + return self.builder.fmul(base0, base0) + + exp0 = self._print(expr.exp) + fn = self.ext_fn.get("pow") + if not fn: + fn_type = ll.FunctionType(self.fp_type, [self.fp_type, self.fp_type]) + fn = ll.Function(self.module, fn_type, "pow") + self.ext_fn["pow"] = fn + return self.builder.call(fn, [base0, exp0], "pow") + + def _print_Mul(self, expr): + nodes = [self._print(a) for a in expr.args] + e = nodes[0] + for node in nodes[1:]: + e = self.builder.fmul(e, node) + return e + + def _print_Add(self, expr): + nodes = [self._print(a) for a in expr.args] + e = nodes[0] + for node in nodes[1:]: + e = self.builder.fadd(e, node) + return e + + # TODO - assumes all called functions take one double precision argument. + # Should have a list of math library functions to validate this. + def _print_Function(self, expr): + name = expr.func.__name__ + e0 = self._print(expr.args[0]) + fn = self.ext_fn.get(name) + if not fn: + fn_type = ll.FunctionType(self.fp_type, [self.fp_type]) + fn = ll.Function(self.module, fn_type, name) + self.ext_fn[name] = fn + return self.builder.call(fn, [e0], name) + + def emptyPrinter(self, expr): + raise TypeError("Unsupported type for LLVM JIT conversion: %s" + % type(expr)) + + +# Used when parameters are passed by array. Often used in callbacks to +# handle a variable number of parameters. +class LLVMJitCallbackPrinter(LLVMJitPrinter): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _print_Indexed(self, expr): + array, idx = self.func_arg_map[expr.base] + offset = int(expr.indices[0].evalf()) + array_ptr = self.builder.gep(array, [ll.Constant(ll.IntType(32), offset)]) + fp_array_ptr = self.builder.bitcast(array_ptr, ll.PointerType(self.fp_type)) + value = self.builder.load(fp_array_ptr) + return value + + def _print_Symbol(self, s): + val = self.tmp_var.get(s) + if val: + return val + + array, idx = self.func_arg_map.get(s, [None, 0]) + if not array: + raise LookupError("Symbol not found: %s" % s) + array_ptr = self.builder.gep(array, [ll.Constant(ll.IntType(32), idx)]) + fp_array_ptr = self.builder.bitcast(array_ptr, + ll.PointerType(self.fp_type)) + value = self.builder.load(fp_array_ptr) + return value + + +# ensure lifetime of the execution engine persists (else call to compiled +# function will seg fault) +exe_engines = [] + +# ensure names for generated functions are unique +link_names = set() +current_link_suffix = 0 + + +class LLVMJitCode: + def __init__(self, signature): + self.signature = signature + self.fp_type = ll.DoubleType() + self.module = ll.Module('mod1') + self.fn = None + self.llvm_arg_types = [] + self.llvm_ret_type = self.fp_type + self.param_dict = {} # map symbol name to LLVM function argument + self.link_name = '' + + def _from_ctype(self, ctype): + if ctype == ctypes.c_int: + return ll.IntType(32) + if ctype == ctypes.c_double: + return self.fp_type + if ctype == ctypes.POINTER(ctypes.c_double): + return ll.PointerType(self.fp_type) + if ctype == ctypes.c_void_p: + return ll.PointerType(ll.IntType(32)) + if ctype == ctypes.py_object: + return ll.PointerType(ll.IntType(32)) + + print("Unhandled ctype = %s" % str(ctype)) + + def _create_args(self, func_args): + """Create types for function arguments""" + self.llvm_ret_type = self._from_ctype(self.signature.ret_type) + self.llvm_arg_types = \ + [self._from_ctype(a) for a in self.signature.arg_ctypes] + + def _create_function_base(self): + """Create function with name and type signature""" + global current_link_suffix + default_link_name = 'jit_func' + current_link_suffix += 1 + self.link_name = default_link_name + str(current_link_suffix) + link_names.add(self.link_name) + + fn_type = ll.FunctionType(self.llvm_ret_type, self.llvm_arg_types) + self.fn = ll.Function(self.module, fn_type, name=self.link_name) + + def _create_param_dict(self, func_args): + """Mapping of symbolic values to function arguments""" + for i, a in enumerate(func_args): + self.fn.args[i].name = str(a) + self.param_dict[a] = self.fn.args[i] + + def _create_function(self, expr): + """Create function body and return LLVM IR""" + bb_entry = self.fn.append_basic_block('entry') + builder = ll.IRBuilder(bb_entry) + + lj = LLVMJitPrinter(self.module, builder, self.fn, + func_arg_map=self.param_dict) + + ret = self._convert_expr(lj, expr) + lj.builder.ret(self._wrap_return(lj, ret)) + + strmod = str(self.module) + return strmod + + def _wrap_return(self, lj, vals): + # Return a single double if there is one return value, + # else return a tuple of doubles. + + # Don't wrap return value in this case + if self.signature.ret_type == ctypes.c_double: + return vals[0] + + # Use this instead of a real PyObject* + void_ptr = ll.PointerType(ll.IntType(32)) + + # Create a wrapped double: PyObject* PyFloat_FromDouble(double v) + wrap_type = ll.FunctionType(void_ptr, [self.fp_type]) + wrap_fn = ll.Function(lj.module, wrap_type, "PyFloat_FromDouble") + + wrapped_vals = [lj.builder.call(wrap_fn, [v]) for v in vals] + if len(vals) == 1: + final_val = wrapped_vals[0] + else: + # Create a tuple: PyObject* PyTuple_Pack(Py_ssize_t n, ...) + + # This should be Py_ssize_t + tuple_arg_types = [ll.IntType(32)] + + tuple_arg_types.extend([void_ptr]*len(vals)) + tuple_type = ll.FunctionType(void_ptr, tuple_arg_types) + tuple_fn = ll.Function(lj.module, tuple_type, "PyTuple_Pack") + + tuple_args = [ll.Constant(ll.IntType(32), len(wrapped_vals))] + tuple_args.extend(wrapped_vals) + + final_val = lj.builder.call(tuple_fn, tuple_args) + + return final_val + + def _convert_expr(self, lj, expr): + try: + # Match CSE return data structure. + if len(expr) == 2: + tmp_exprs = expr[0] + final_exprs = expr[1] + if len(final_exprs) != 1 and self.signature.ret_type == ctypes.c_double: + raise NotImplementedError("Return of multiple expressions not supported for this callback") + for name, e in tmp_exprs: + val = lj._print(e) + lj._add_tmp_var(name, val) + except TypeError: + final_exprs = [expr] + + vals = [lj._print(e) for e in final_exprs] + + return vals + + def _compile_function(self, strmod): + llmod = llvm.parse_assembly(strmod) + + pmb = llvm.create_pass_manager_builder() + pmb.opt_level = 2 + pass_manager = llvm.create_module_pass_manager() + pmb.populate(pass_manager) + + pass_manager.run(llmod) + + target_machine = \ + llvm.Target.from_default_triple().create_target_machine() + exe_eng = llvm.create_mcjit_compiler(llmod, target_machine) + exe_eng.finalize_object() + exe_engines.append(exe_eng) + + if False: + print("Assembly") + print(target_machine.emit_assembly(llmod)) + + fptr = exe_eng.get_function_address(self.link_name) + + return fptr + + +class LLVMJitCodeCallback(LLVMJitCode): + def __init__(self, signature): + super().__init__(signature) + + def _create_param_dict(self, func_args): + for i, a in enumerate(func_args): + if isinstance(a, IndexedBase): + self.param_dict[a] = (self.fn.args[i], i) + self.fn.args[i].name = str(a) + else: + self.param_dict[a] = (self.fn.args[self.signature.input_arg], + i) + + def _create_function(self, expr): + """Create function body and return LLVM IR""" + bb_entry = self.fn.append_basic_block('entry') + builder = ll.IRBuilder(bb_entry) + + lj = LLVMJitCallbackPrinter(self.module, builder, self.fn, + func_arg_map=self.param_dict) + + ret = self._convert_expr(lj, expr) + + if self.signature.ret_arg: + output_fp_ptr = builder.bitcast(self.fn.args[self.signature.ret_arg], + ll.PointerType(self.fp_type)) + for i, val in enumerate(ret): + index = ll.Constant(ll.IntType(32), i) + output_array_ptr = builder.gep(output_fp_ptr, [index]) + builder.store(val, output_array_ptr) + builder.ret(ll.Constant(ll.IntType(32), 0)) # return success + else: + lj.builder.ret(self._wrap_return(lj, ret)) + + strmod = str(self.module) + return strmod + + +class CodeSignature: + def __init__(self, ret_type): + self.ret_type = ret_type + self.arg_ctypes = [] + + # Input argument array element index + self.input_arg = 0 + + # For the case output value is referenced through a parameter rather + # than the return value + self.ret_arg = None + + +def _llvm_jit_code(args, expr, signature, callback_type): + """Create a native code function from a SymPy expression""" + if callback_type is None: + jit = LLVMJitCode(signature) + else: + jit = LLVMJitCodeCallback(signature) + + jit._create_args(args) + jit._create_function_base() + jit._create_param_dict(args) + strmod = jit._create_function(expr) + if False: + print("LLVM IR") + print(strmod) + fptr = jit._compile_function(strmod) + return fptr + + +@doctest_depends_on(modules=('llvmlite', 'scipy')) +def llvm_callable(args, expr, callback_type=None): + '''Compile function from a SymPy expression + + Expressions are evaluated using double precision arithmetic. + Some single argument math functions (exp, sin, cos, etc.) are supported + in expressions. + + Parameters + ========== + + args : List of Symbol + Arguments to the generated function. Usually the free symbols in + the expression. Currently each one is assumed to convert to + a double precision scalar. + expr : Expr, or (Replacements, Expr) as returned from 'cse' + Expression to compile. + callback_type : string + Create function with signature appropriate to use as a callback. + Currently supported: + 'scipy.integrate' + 'scipy.integrate.test' + 'cubature' + + Returns + ======= + + Compiled function that can evaluate the expression. + + Examples + ======== + + >>> import sympy.printing.llvmjitcode as jit + >>> from sympy.abc import a + >>> e = a*a + a + 1 + >>> e1 = jit.llvm_callable([a], e) + >>> e.subs(a, 1.1) # Evaluate via substitution + 3.31000000000000 + >>> e1(1.1) # Evaluate using JIT-compiled code + 3.3100000000000005 + + + Callbacks for integration functions can be JIT compiled. + + >>> import sympy.printing.llvmjitcode as jit + >>> from sympy.abc import a + >>> from sympy import integrate + >>> from scipy.integrate import quad + >>> e = a*a + >>> e1 = jit.llvm_callable([a], e, callback_type='scipy.integrate') + >>> integrate(e, (a, 0.0, 2.0)) + 2.66666666666667 + >>> quad(e1, 0.0, 2.0)[0] + 2.66666666666667 + + The 'cubature' callback is for the Python wrapper around the + cubature package ( https://github.com/saullocastro/cubature ) + and ( http://ab-initio.mit.edu/wiki/index.php/Cubature ) + + There are two signatures for the SciPy integration callbacks. + The first ('scipy.integrate') is the function to be passed to the + integration routine, and will pass the signature checks. + The second ('scipy.integrate.test') is only useful for directly calling + the function using ctypes variables. It will not pass the signature checks + for scipy.integrate. + + The return value from the cse module can also be compiled. This + can improve the performance of the compiled function. If multiple + expressions are given to cse, the compiled function returns a tuple. + The 'cubature' callback handles multiple expressions (set `fdim` + to match in the integration call.) + + >>> import sympy.printing.llvmjitcode as jit + >>> from sympy import cse + >>> from sympy.abc import x,y + >>> e1 = x*x + y*y + >>> e2 = 4*(x*x + y*y) + 8.0 + >>> after_cse = cse([e1,e2]) + >>> after_cse + ([(x0, x**2), (x1, y**2)], [x0 + x1, 4*x0 + 4*x1 + 8.0]) + >>> j1 = jit.llvm_callable([x,y], after_cse) + >>> j1(1.0, 2.0) + (5.0, 28.0) + ''' + + if not llvmlite: + raise ImportError("llvmlite is required for llvmjitcode") + + signature = CodeSignature(ctypes.py_object) + + arg_ctypes = [] + if callback_type is None: + for _ in args: + arg_ctype = ctypes.c_double + arg_ctypes.append(arg_ctype) + elif callback_type in ('scipy.integrate', 'scipy.integrate.test'): + signature.ret_type = ctypes.c_double + arg_ctypes = [ctypes.c_int, ctypes.POINTER(ctypes.c_double)] + arg_ctypes_formal = [ctypes.c_int, ctypes.c_double] + signature.input_arg = 1 + elif callback_type == 'cubature': + arg_ctypes = [ctypes.c_int, + ctypes.POINTER(ctypes.c_double), + ctypes.c_void_p, + ctypes.c_int, + ctypes.POINTER(ctypes.c_double) + ] + signature.ret_type = ctypes.c_int + signature.input_arg = 1 + signature.ret_arg = 4 + else: + raise ValueError("Unknown callback type: %s" % callback_type) + + signature.arg_ctypes = arg_ctypes + + fptr = _llvm_jit_code(args, expr, signature, callback_type) + + if callback_type and callback_type == 'scipy.integrate': + arg_ctypes = arg_ctypes_formal + + # PYFUNCTYPE holds the GIL which is needed to prevent a segfault when + # calling PyFloat_FromDouble on Python 3.10. Probably it is better to use + # ctypes.c_double when returning a float rather than using ctypes.py_object + # and returning a PyFloat from inside the jitted function (i.e. let ctypes + # handle the conversion from double to PyFloat). + if signature.ret_type == ctypes.py_object: + FUNCTYPE = ctypes.PYFUNCTYPE + else: + FUNCTYPE = ctypes.CFUNCTYPE + + cfunc = FUNCTYPE(signature.ret_type, *arg_ctypes)(fptr) + return cfunc diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/octave.py b/.venv/lib/python3.13/site-packages/sympy/printing/octave.py new file mode 100644 index 0000000000000000000000000000000000000000..2cf2d6a5754668d7a95ef5dc7b27b4864756a9e5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/octave.py @@ -0,0 +1,711 @@ +""" +Octave (and Matlab) code printer + +The `OctaveCodePrinter` converts SymPy expressions into Octave expressions. +It uses a subset of the Octave language for Matlab compatibility. + +A complete code generator, which uses `octave_code` extensively, can be found +in `sympy.utilities.codegen`. The `codegen` module can be used to generate +complete source code files. + +""" + +from __future__ import annotations +from typing import Any + +from sympy.core import Mul, Pow, S, Rational +from sympy.core.mul import _keep_coeff +from sympy.core.numbers import equal_valued +from sympy.printing.codeprinter import CodePrinter +from sympy.printing.precedence import precedence, PRECEDENCE +from re import search + +# List of known functions. First, those that have the same name in +# SymPy and Octave. This is almost certainly incomplete! +known_fcns_src1 = ["sin", "cos", "tan", "cot", "sec", "csc", + "asin", "acos", "acot", "atan", "atan2", "asec", "acsc", + "sinh", "cosh", "tanh", "coth", "csch", "sech", + "asinh", "acosh", "atanh", "acoth", "asech", "acsch", + "erfc", "erfi", "erf", "erfinv", "erfcinv", + "besseli", "besselj", "besselk", "bessely", + "bernoulli", "beta", "euler", "exp", "factorial", "floor", + "fresnelc", "fresnels", "gamma", "harmonic", "log", + "polylog", "sign", "zeta", "legendre"] + +# These functions have different names ("SymPy": "Octave"), more +# generally a mapping to (argument_conditions, octave_function). +known_fcns_src2 = { + "Abs": "abs", + "arg": "angle", # arg/angle ok in Octave but only angle in Matlab + "binomial": "bincoeff", + "ceiling": "ceil", + "chebyshevu": "chebyshevU", + "chebyshevt": "chebyshevT", + "Chi": "coshint", + "Ci": "cosint", + "conjugate": "conj", + "DiracDelta": "dirac", + "Heaviside": "heaviside", + "im": "imag", + "laguerre": "laguerreL", + "LambertW": "lambertw", + "li": "logint", + "loggamma": "gammaln", + "Max": "max", + "Min": "min", + "Mod": "mod", + "polygamma": "psi", + "re": "real", + "RisingFactorial": "pochhammer", + "Shi": "sinhint", + "Si": "sinint", +} + + +class OctaveCodePrinter(CodePrinter): + """ + A printer to convert expressions to strings of Octave/Matlab code. + """ + printmethod = "_octave" + language = "Octave" + + _operators = { + 'and': '&', + 'or': '|', + 'not': '~', + } + + _default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{ + 'precision': 17, + 'user_functions': {}, + 'contract': True, + 'inline': True, + }) + # Note: contract is for expressing tensors as loops (if True), or just + # assignment (if False). FIXME: this should be looked a more carefully + # for Octave. + + + def __init__(self, settings={}): + super().__init__(settings) + self.known_functions = dict(zip(known_fcns_src1, known_fcns_src1)) + self.known_functions.update(dict(known_fcns_src2)) + userfuncs = settings.get('user_functions', {}) + self.known_functions.update(userfuncs) + + + def _rate_index_position(self, p): + return p*5 + + + def _get_statement(self, codestring): + return "%s;" % codestring + + + def _get_comment(self, text): + return "% {}".format(text) + + + def _declare_number_const(self, name, value): + return "{} = {};".format(name, value) + + + def _format_code(self, lines): + return self.indent_code(lines) + + + def _traverse_matrix_indices(self, mat): + # Octave uses Fortran order (column-major) + rows, cols = mat.shape + return ((i, j) for j in range(cols) for i in range(rows)) + + + def _get_loop_opening_ending(self, indices): + open_lines = [] + close_lines = [] + for i in indices: + # Octave arrays start at 1 and end at dimension + var, start, stop = map(self._print, + [i.label, i.lower + 1, i.upper + 1]) + open_lines.append("for %s = %s:%s" % (var, start, stop)) + close_lines.append("end") + return open_lines, close_lines + + + def _print_Mul(self, expr): + # print complex numbers nicely in Octave + if (expr.is_number and expr.is_imaginary and + (S.ImaginaryUnit*expr).is_Integer): + return "%si" % self._print(-S.ImaginaryUnit*expr) + + # cribbed from str.py + prec = precedence(expr) + + c, e = expr.as_coeff_Mul() + if c < 0: + expr = _keep_coeff(-c, e) + sign = "-" + else: + sign = "" + + a = [] # items in the numerator + b = [] # items that are in the denominator (if any) + + pow_paren = [] # Will collect all pow with more than one base element and exp = -1 + + if self.order not in ('old', 'none'): + args = expr.as_ordered_factors() + else: + # use make_args in case expr was something like -x -> x + args = Mul.make_args(expr) + + # Gather args for numerator/denominator + for item in args: + if (item.is_commutative and item.is_Pow and item.exp.is_Rational + and item.exp.is_negative): + if item.exp != -1: + b.append(Pow(item.base, -item.exp, evaluate=False)) + else: + if len(item.args[0].args) != 1 and isinstance(item.base, Mul): # To avoid situations like #14160 + pow_paren.append(item) + b.append(Pow(item.base, -item.exp)) + elif item.is_Rational and item is not S.Infinity: + if item.p != 1: + a.append(Rational(item.p)) + if item.q != 1: + b.append(Rational(item.q)) + else: + a.append(item) + + a = a or [S.One] + + a_str = [self.parenthesize(x, prec) for x in a] + b_str = [self.parenthesize(x, prec) for x in b] + + # To parenthesize Pow with exp = -1 and having more than one Symbol + for item in pow_paren: + if item.base in b: + b_str[b.index(item.base)] = "(%s)" % b_str[b.index(item.base)] + + # from here it differs from str.py to deal with "*" and ".*" + def multjoin(a, a_str): + # here we probably are assuming the constants will come first + r = a_str[0] + for i in range(1, len(a)): + mulsym = '*' if a[i-1].is_number else '.*' + r = r + mulsym + a_str[i] + return r + + if not b: + return sign + multjoin(a, a_str) + elif len(b) == 1: + divsym = '/' if b[0].is_number else './' + return sign + multjoin(a, a_str) + divsym + b_str[0] + else: + divsym = '/' if all(bi.is_number for bi in b) else './' + return (sign + multjoin(a, a_str) + + divsym + "(%s)" % multjoin(b, b_str)) + + def _print_Relational(self, expr): + lhs_code = self._print(expr.lhs) + rhs_code = self._print(expr.rhs) + op = expr.rel_op + return "{} {} {}".format(lhs_code, op, rhs_code) + + def _print_Pow(self, expr): + powsymbol = '^' if all(x.is_number for x in expr.args) else '.^' + + PREC = precedence(expr) + + if equal_valued(expr.exp, 0.5): + return "sqrt(%s)" % self._print(expr.base) + + if expr.is_commutative: + if equal_valued(expr.exp, -0.5): + sym = '/' if expr.base.is_number else './' + return "1" + sym + "sqrt(%s)" % self._print(expr.base) + if equal_valued(expr.exp, -1): + sym = '/' if expr.base.is_number else './' + return "1" + sym + "%s" % self.parenthesize(expr.base, PREC) + + return '%s%s%s' % (self.parenthesize(expr.base, PREC), powsymbol, + self.parenthesize(expr.exp, PREC)) + + + def _print_MatPow(self, expr): + PREC = precedence(expr) + return '%s^%s' % (self.parenthesize(expr.base, PREC), + self.parenthesize(expr.exp, PREC)) + + def _print_MatrixSolve(self, expr): + PREC = precedence(expr) + return "%s \\ %s" % (self.parenthesize(expr.matrix, PREC), + self.parenthesize(expr.vector, PREC)) + + def _print_Pi(self, expr): + return 'pi' + + + def _print_ImaginaryUnit(self, expr): + return "1i" + + + def _print_Exp1(self, expr): + return "exp(1)" + + + def _print_GoldenRatio(self, expr): + # FIXME: how to do better, e.g., for octave_code(2*GoldenRatio)? + #return self._print((1+sqrt(S(5)))/2) + return "(1+sqrt(5))/2" + + + def _print_Assignment(self, expr): + from sympy.codegen.ast import Assignment + from sympy.functions.elementary.piecewise import Piecewise + from sympy.tensor.indexed import IndexedBase + # Copied from codeprinter, but remove special MatrixSymbol treatment + lhs = expr.lhs + rhs = expr.rhs + # We special case assignments that take multiple lines + if not self._settings["inline"] and isinstance(expr.rhs, Piecewise): + # Here we modify Piecewise so each expression is now + # an Assignment, and then continue on the print. + expressions = [] + conditions = [] + for (e, c) in rhs.args: + expressions.append(Assignment(lhs, e)) + conditions.append(c) + temp = Piecewise(*zip(expressions, conditions)) + return self._print(temp) + if self._settings["contract"] and (lhs.has(IndexedBase) or + rhs.has(IndexedBase)): + # Here we check if there is looping to be done, and if so + # print the required loops. + return self._doprint_loops(rhs, lhs) + else: + lhs_code = self._print(lhs) + rhs_code = self._print(rhs) + return self._get_statement("%s = %s" % (lhs_code, rhs_code)) + + + def _print_Infinity(self, expr): + return 'inf' + + + def _print_NegativeInfinity(self, expr): + return '-inf' + + + def _print_NaN(self, expr): + return 'NaN' + + + def _print_list(self, expr): + return '{' + ', '.join(self._print(a) for a in expr) + '}' + _print_tuple = _print_list + _print_Tuple = _print_list + _print_List = _print_list + + + def _print_BooleanTrue(self, expr): + return "true" + + + def _print_BooleanFalse(self, expr): + return "false" + + + def _print_bool(self, expr): + return str(expr).lower() + + + # Could generate quadrature code for definite Integrals? + #_print_Integral = _print_not_supported + + + def _print_MatrixBase(self, A): + # Handle zero dimensions: + if (A.rows, A.cols) == (0, 0): + return '[]' + elif S.Zero in A.shape: + return 'zeros(%s, %s)' % (A.rows, A.cols) + elif (A.rows, A.cols) == (1, 1): + # Octave does not distinguish between scalars and 1x1 matrices + return self._print(A[0, 0]) + return "[%s]" % "; ".join(" ".join([self._print(a) for a in A[r, :]]) + for r in range(A.rows)) + + + def _print_SparseRepMatrix(self, A): + from sympy.matrices import Matrix + L = A.col_list() + # make row vectors of the indices and entries + I = Matrix([[k[0] + 1 for k in L]]) + J = Matrix([[k[1] + 1 for k in L]]) + AIJ = Matrix([[k[2] for k in L]]) + return "sparse(%s, %s, %s, %s, %s)" % (self._print(I), self._print(J), + self._print(AIJ), A.rows, A.cols) + + + def _print_MatrixElement(self, expr): + return self.parenthesize(expr.parent, PRECEDENCE["Atom"], strict=True) \ + + '(%s, %s)' % (expr.i + 1, expr.j + 1) + + + def _print_MatrixSlice(self, expr): + def strslice(x, lim): + l = x[0] + 1 + h = x[1] + step = x[2] + lstr = self._print(l) + hstr = 'end' if h == lim else self._print(h) + if step == 1: + if l == 1 and h == lim: + return ':' + if l == h: + return lstr + else: + return lstr + ':' + hstr + else: + return ':'.join((lstr, self._print(step), hstr)) + return (self._print(expr.parent) + '(' + + strslice(expr.rowslice, expr.parent.shape[0]) + ', ' + + strslice(expr.colslice, expr.parent.shape[1]) + ')') + + + def _print_Indexed(self, expr): + inds = [ self._print(i) for i in expr.indices ] + return "%s(%s)" % (self._print(expr.base.label), ", ".join(inds)) + + + def _print_KroneckerDelta(self, expr): + prec = PRECEDENCE["Pow"] + return "double(%s == %s)" % tuple(self.parenthesize(x, prec) + for x in expr.args) + + def _print_HadamardProduct(self, expr): + return '.*'.join([self.parenthesize(arg, precedence(expr)) + for arg in expr.args]) + + def _print_HadamardPower(self, expr): + PREC = precedence(expr) + return '.**'.join([ + self.parenthesize(expr.base, PREC), + self.parenthesize(expr.exp, PREC) + ]) + + def _print_Identity(self, expr): + shape = expr.shape + if len(shape) == 2 and shape[0] == shape[1]: + shape = [shape[0]] + s = ", ".join(self._print(n) for n in shape) + return "eye(" + s + ")" + + def _print_lowergamma(self, expr): + # Octave implements regularized incomplete gamma function + return "(gammainc({1}, {0}).*gamma({0}))".format( + self._print(expr.args[0]), self._print(expr.args[1])) + + + def _print_uppergamma(self, expr): + return "(gammainc({1}, {0}, 'upper').*gamma({0}))".format( + self._print(expr.args[0]), self._print(expr.args[1])) + + + def _print_sinc(self, expr): + #Note: Divide by pi because Octave implements normalized sinc function. + return "sinc(%s)" % self._print(expr.args[0]/S.Pi) + + + def _print_hankel1(self, expr): + return "besselh(%s, 1, %s)" % (self._print(expr.order), + self._print(expr.argument)) + + + def _print_hankel2(self, expr): + return "besselh(%s, 2, %s)" % (self._print(expr.order), + self._print(expr.argument)) + + + # Note: as of 2015, Octave doesn't have spherical Bessel functions + def _print_jn(self, expr): + from sympy.functions import sqrt, besselj + x = expr.argument + expr2 = sqrt(S.Pi/(2*x))*besselj(expr.order + S.Half, x) + return self._print(expr2) + + + def _print_yn(self, expr): + from sympy.functions import sqrt, bessely + x = expr.argument + expr2 = sqrt(S.Pi/(2*x))*bessely(expr.order + S.Half, x) + return self._print(expr2) + + + def _print_airyai(self, expr): + return "airy(0, %s)" % self._print(expr.args[0]) + + + def _print_airyaiprime(self, expr): + return "airy(1, %s)" % self._print(expr.args[0]) + + + def _print_airybi(self, expr): + return "airy(2, %s)" % self._print(expr.args[0]) + + + def _print_airybiprime(self, expr): + return "airy(3, %s)" % self._print(expr.args[0]) + + + def _print_expint(self, expr): + mu, x = expr.args + if mu != 1: + return self._print_not_supported(expr) + return "expint(%s)" % self._print(x) + + + def _one_or_two_reversed_args(self, expr): + assert len(expr.args) <= 2 + return '{name}({args})'.format( + name=self.known_functions[expr.__class__.__name__], + args=", ".join([self._print(x) for x in reversed(expr.args)]) + ) + + + _print_DiracDelta = _print_LambertW = _one_or_two_reversed_args + + + def _nested_binary_math_func(self, expr): + return '{name}({arg1}, {arg2})'.format( + name=self.known_functions[expr.__class__.__name__], + arg1=self._print(expr.args[0]), + arg2=self._print(expr.func(*expr.args[1:])) + ) + + _print_Max = _print_Min = _nested_binary_math_func + + + def _print_Piecewise(self, expr): + if expr.args[-1].cond != True: + # We need the last conditional to be a True, otherwise the resulting + # function may not return a result. + raise ValueError("All Piecewise expressions must contain an " + "(expr, True) statement to be used as a default " + "condition. Without one, the generated " + "expression may not evaluate to anything under " + "some condition.") + lines = [] + if self._settings["inline"]: + # Express each (cond, expr) pair in a nested Horner form: + # (condition) .* (expr) + (not cond) .* () + # Expressions that result in multiple statements won't work here. + ecpairs = ["({0}).*({1}) + (~({0})).*(".format + (self._print(c), self._print(e)) + for e, c in expr.args[:-1]] + elast = "%s" % self._print(expr.args[-1].expr) + pw = " ...\n".join(ecpairs) + elast + ")"*len(ecpairs) + # Note: current need these outer brackets for 2*pw. Would be + # nicer to teach parenthesize() to do this for us when needed! + return "(" + pw + ")" + else: + for i, (e, c) in enumerate(expr.args): + if i == 0: + lines.append("if (%s)" % self._print(c)) + elif i == len(expr.args) - 1 and c == True: + lines.append("else") + else: + lines.append("elseif (%s)" % self._print(c)) + code0 = self._print(e) + lines.append(code0) + if i == len(expr.args) - 1: + lines.append("end") + return "\n".join(lines) + + + def _print_zeta(self, expr): + if len(expr.args) == 1: + return "zeta(%s)" % self._print(expr.args[0]) + else: + # Matlab two argument zeta is not equivalent to SymPy's + return self._print_not_supported(expr) + + + def indent_code(self, code): + """Accepts a string of code or a list of code lines""" + + # code mostly copied from ccode + if isinstance(code, str): + code_lines = self.indent_code(code.splitlines(True)) + return ''.join(code_lines) + + tab = " " + inc_regex = ('^function ', '^if ', '^elseif ', '^else$', '^for ') + dec_regex = ('^end$', '^elseif ', '^else$') + + # pre-strip left-space from the code + code = [ line.lstrip(' \t') for line in code ] + + increase = [ int(any(search(re, line) for re in inc_regex)) + for line in code ] + decrease = [ int(any(search(re, line) for re in dec_regex)) + for line in code ] + + pretty = [] + level = 0 + for n, line in enumerate(code): + if line in ('', '\n'): + pretty.append(line) + continue + level -= decrease[n] + pretty.append("%s%s" % (tab*level, line)) + level += increase[n] + return pretty + + +def octave_code(expr, assign_to=None, **settings): + r"""Converts `expr` to a string of Octave (or Matlab) code. + + The string uses a subset of the Octave language for Matlab compatibility. + + Parameters + ========== + + expr : Expr + A SymPy expression to be converted. + assign_to : optional + When given, the argument is used as the name of the variable to which + the expression is assigned. Can be a string, ``Symbol``, + ``MatrixSymbol``, or ``Indexed`` type. This can be helpful for + expressions that generate multi-line statements. + precision : integer, optional + The precision for numbers such as pi [default=16]. + user_functions : dict, optional + A dictionary where keys are ``FunctionClass`` instances and values are + their string representations. Alternatively, the dictionary value can + be a list of tuples i.e. [(argument_test, cfunction_string)]. See + below for examples. + human : bool, optional + If True, the result is a single string that may contain some constant + declarations for the number symbols. If False, the same information is + returned in a tuple of (symbols_to_declare, not_supported_functions, + code_text). [default=True]. + contract: bool, optional + If True, ``Indexed`` instances are assumed to obey tensor contraction + rules and the corresponding nested loops over indices are generated. + Setting contract=False will not generate loops, instead the user is + responsible to provide values for the indices in the code. + [default=True]. + inline: bool, optional + If True, we try to create single-statement code instead of multiple + statements. [default=True]. + + Examples + ======== + + >>> from sympy import octave_code, symbols, sin, pi + >>> x = symbols('x') + >>> octave_code(sin(x).series(x).removeO()) + 'x.^5/120 - x.^3/6 + x' + + >>> from sympy import Rational, ceiling + >>> x, y, tau = symbols("x, y, tau") + >>> octave_code((2*tau)**Rational(7, 2)) + '8*sqrt(2)*tau.^(7/2)' + + Note that element-wise (Hadamard) operations are used by default between + symbols. This is because its very common in Octave to write "vectorized" + code. It is harmless if the values are scalars. + + >>> octave_code(sin(pi*x*y), assign_to="s") + 's = sin(pi*x.*y);' + + If you need a matrix product "*" or matrix power "^", you can specify the + symbol as a ``MatrixSymbol``. + + >>> from sympy import Symbol, MatrixSymbol + >>> n = Symbol('n', integer=True, positive=True) + >>> A = MatrixSymbol('A', n, n) + >>> octave_code(3*pi*A**3) + '(3*pi)*A^3' + + This class uses several rules to decide which symbol to use a product. + Pure numbers use "*", Symbols use ".*" and MatrixSymbols use "*". + A HadamardProduct can be used to specify componentwise multiplication ".*" + of two MatrixSymbols. There is currently there is no easy way to specify + scalar symbols, so sometimes the code might have some minor cosmetic + issues. For example, suppose x and y are scalars and A is a Matrix, then + while a human programmer might write "(x^2*y)*A^3", we generate: + + >>> octave_code(x**2*y*A**3) + '(x.^2.*y)*A^3' + + Matrices are supported using Octave inline notation. When using + ``assign_to`` with matrices, the name can be specified either as a string + or as a ``MatrixSymbol``. The dimensions must align in the latter case. + + >>> from sympy import Matrix, MatrixSymbol + >>> mat = Matrix([[x**2, sin(x), ceiling(x)]]) + >>> octave_code(mat, assign_to='A') + 'A = [x.^2 sin(x) ceil(x)];' + + ``Piecewise`` expressions are implemented with logical masking by default. + Alternatively, you can pass "inline=False" to use if-else conditionals. + Note that if the ``Piecewise`` lacks a default term, represented by + ``(expr, True)`` then an error will be thrown. This is to prevent + generating an expression that may not evaluate to anything. + + >>> from sympy import Piecewise + >>> pw = Piecewise((x + 1, x > 0), (x, True)) + >>> octave_code(pw, assign_to=tau) + 'tau = ((x > 0).*(x + 1) + (~(x > 0)).*(x));' + + Note that any expression that can be generated normally can also exist + inside a Matrix: + + >>> mat = Matrix([[x**2, pw, sin(x)]]) + >>> octave_code(mat, assign_to='A') + 'A = [x.^2 ((x > 0).*(x + 1) + (~(x > 0)).*(x)) sin(x)];' + + Custom printing can be defined for certain types by passing a dictionary of + "type" : "function" to the ``user_functions`` kwarg. Alternatively, the + dictionary value can be a list of tuples i.e., [(argument_test, + cfunction_string)]. This can be used to call a custom Octave function. + + >>> from sympy import Function + >>> f = Function('f') + >>> g = Function('g') + >>> custom_functions = { + ... "f": "existing_octave_fcn", + ... "g": [(lambda x: x.is_Matrix, "my_mat_fcn"), + ... (lambda x: not x.is_Matrix, "my_fcn")] + ... } + >>> mat = Matrix([[1, x]]) + >>> octave_code(f(x) + g(x) + g(mat), user_functions=custom_functions) + 'existing_octave_fcn(x) + my_fcn(x) + my_mat_fcn([1 x])' + + Support for loops is provided through ``Indexed`` types. With + ``contract=True`` these expressions will be turned into loops, whereas + ``contract=False`` will just print the assignment expression that should be + looped over: + + >>> from sympy import Eq, IndexedBase, Idx + >>> len_y = 5 + >>> y = IndexedBase('y', shape=(len_y,)) + >>> t = IndexedBase('t', shape=(len_y,)) + >>> Dy = IndexedBase('Dy', shape=(len_y-1,)) + >>> i = Idx('i', len_y-1) + >>> e = Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i])) + >>> octave_code(e.rhs, assign_to=e.lhs, contract=False) + 'Dy(i) = (y(i + 1) - y(i))./(t(i + 1) - t(i));' + """ + return OctaveCodePrinter(settings).doprint(expr, assign_to) + + +def print_octave_code(expr, **settings): + """Prints the Octave (or Matlab) representation of the given expression. + + See `octave_code` for the meaning of the optional arguments. + """ + print(octave_code(expr, **settings)) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/precedence.py b/.venv/lib/python3.13/site-packages/sympy/printing/precedence.py new file mode 100644 index 0000000000000000000000000000000000000000..d22d5746aeee51bddcf273d4575c30c3c27db71a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/precedence.py @@ -0,0 +1,180 @@ +"""A module providing information about the necessity of brackets""" + + +# Default precedence values for some basic types +PRECEDENCE = { + "Lambda": 1, + "Xor": 10, + "Or": 20, + "And": 30, + "Relational": 35, + "Add": 40, + "Mul": 50, + "Pow": 60, + "Func": 70, + "Not": 100, + "Atom": 1000, + "BitwiseOr": 36, + "BitwiseXor": 37, + "BitwiseAnd": 38 +} + +# A dictionary assigning precedence values to certain classes. These values are +# treated like they were inherited, so not every single class has to be named +# here. +# Do not use this with printers other than StrPrinter +PRECEDENCE_VALUES = { + "Equivalent": PRECEDENCE["Xor"], + "Xor": PRECEDENCE["Xor"], + "Implies": PRECEDENCE["Xor"], + "Or": PRECEDENCE["Or"], + "And": PRECEDENCE["And"], + "Add": PRECEDENCE["Add"], + "Pow": PRECEDENCE["Pow"], + "Relational": PRECEDENCE["Relational"], + "Sub": PRECEDENCE["Add"], + "Not": PRECEDENCE["Not"], + "Function" : PRECEDENCE["Func"], + "NegativeInfinity": PRECEDENCE["Add"], + "MatAdd": PRECEDENCE["Add"], + "MatPow": PRECEDENCE["Pow"], + "MatrixSolve": PRECEDENCE["Mul"], + "Mod": PRECEDENCE["Mul"], + "TensAdd": PRECEDENCE["Add"], + # As soon as `TensMul` is a subclass of `Mul`, remove this: + "TensMul": PRECEDENCE["Mul"], + "HadamardProduct": PRECEDENCE["Mul"], + "HadamardPower": PRECEDENCE["Pow"], + "KroneckerProduct": PRECEDENCE["Mul"], + "Equality": PRECEDENCE["Mul"], + "Unequality": PRECEDENCE["Mul"], +} + +# Sometimes it's not enough to assign a fixed precedence value to a +# class. Then a function can be inserted in this dictionary that takes +# an instance of this class as argument and returns the appropriate +# precedence value. + +# Precedence functions + + +def precedence_Mul(item): + from sympy.core.function import Function + if any(hasattr(arg, 'precedence') and isinstance(arg, Function) and + arg.precedence < PRECEDENCE["Mul"] for arg in item.args): + return PRECEDENCE["Mul"] + + if item.could_extract_minus_sign(): + return PRECEDENCE["Add"] + return PRECEDENCE["Mul"] + + +def precedence_Rational(item): + if item.p < 0: + return PRECEDENCE["Add"] + return PRECEDENCE["Mul"] + + +def precedence_Integer(item): + if item.p < 0: + return PRECEDENCE["Add"] + return PRECEDENCE["Atom"] + + +def precedence_Float(item): + if item < 0: + return PRECEDENCE["Add"] + return PRECEDENCE["Atom"] + + +def precedence_PolyElement(item): + if item.is_generator: + return PRECEDENCE["Atom"] + elif item.is_ground: + return precedence(item.coeff(1)) + elif item.is_term: + return PRECEDENCE["Mul"] + else: + return PRECEDENCE["Add"] + + +def precedence_FracElement(item): + if item.denom == 1: + return precedence_PolyElement(item.numer) + else: + return PRECEDENCE["Mul"] + + +def precedence_UnevaluatedExpr(item): + return precedence(item.args[0]) - 0.5 + + +PRECEDENCE_FUNCTIONS = { + "Integer": precedence_Integer, + "Mul": precedence_Mul, + "Rational": precedence_Rational, + "Float": precedence_Float, + "PolyElement": precedence_PolyElement, + "FracElement": precedence_FracElement, + "UnevaluatedExpr": precedence_UnevaluatedExpr, +} + + +def precedence(item): + """Returns the precedence of a given object. + + This is the precedence for StrPrinter. + """ + if hasattr(item, "precedence"): + return item.precedence + if not isinstance(item, type): + for i in type(item).mro(): + n = i.__name__ + if n in PRECEDENCE_FUNCTIONS: + return PRECEDENCE_FUNCTIONS[n](item) + elif n in PRECEDENCE_VALUES: + return PRECEDENCE_VALUES[n] + return PRECEDENCE["Atom"] + + +PRECEDENCE_TRADITIONAL = PRECEDENCE.copy() +PRECEDENCE_TRADITIONAL['Integral'] = PRECEDENCE["Mul"] +PRECEDENCE_TRADITIONAL['Sum'] = PRECEDENCE["Mul"] +PRECEDENCE_TRADITIONAL['Product'] = PRECEDENCE["Mul"] +PRECEDENCE_TRADITIONAL['Limit'] = PRECEDENCE["Mul"] +PRECEDENCE_TRADITIONAL['Derivative'] = PRECEDENCE["Mul"] +PRECEDENCE_TRADITIONAL['TensorProduct'] = PRECEDENCE["Mul"] +PRECEDENCE_TRADITIONAL['Transpose'] = PRECEDENCE["Pow"] +PRECEDENCE_TRADITIONAL['Adjoint'] = PRECEDENCE["Pow"] +PRECEDENCE_TRADITIONAL['Dot'] = PRECEDENCE["Mul"] - 1 +PRECEDENCE_TRADITIONAL['Cross'] = PRECEDENCE["Mul"] - 1 +PRECEDENCE_TRADITIONAL['Gradient'] = PRECEDENCE["Mul"] - 1 +PRECEDENCE_TRADITIONAL['Divergence'] = PRECEDENCE["Mul"] - 1 +PRECEDENCE_TRADITIONAL['Curl'] = PRECEDENCE["Mul"] - 1 +PRECEDENCE_TRADITIONAL['Laplacian'] = PRECEDENCE["Mul"] - 1 +PRECEDENCE_TRADITIONAL['Union'] = PRECEDENCE['Xor'] +PRECEDENCE_TRADITIONAL['Intersection'] = PRECEDENCE['Xor'] +PRECEDENCE_TRADITIONAL['Complement'] = PRECEDENCE['Xor'] +PRECEDENCE_TRADITIONAL['SymmetricDifference'] = PRECEDENCE['Xor'] +PRECEDENCE_TRADITIONAL['ProductSet'] = PRECEDENCE['Xor'] +PRECEDENCE_TRADITIONAL['DotProduct'] = PRECEDENCE_TRADITIONAL['Dot'] + + +def precedence_traditional(item): + """Returns the precedence of a given object according to the + traditional rules of mathematics. + + This is the precedence for the LaTeX and pretty printer. + """ + # Integral, Sum, Product, Limit have the precedence of Mul in LaTeX, + # the precedence of Atom for other printers: + from sympy.core.expr import UnevaluatedExpr + + if isinstance(item, UnevaluatedExpr): + return precedence_traditional(item.args[0]) + + n = item.__class__.__name__ + if n in PRECEDENCE_TRADITIONAL: + return PRECEDENCE_TRADITIONAL[n] + + return precedence(item) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/printer.py b/.venv/lib/python3.13/site-packages/sympy/printing/printer.py new file mode 100644 index 0000000000000000000000000000000000000000..0c0a6970920cf0928ad330ed9a3ea4291107a29d --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/printer.py @@ -0,0 +1,432 @@ +"""Printing subsystem driver + +SymPy's printing system works the following way: Any expression can be +passed to a designated Printer who then is responsible to return an +adequate representation of that expression. + +**The basic concept is the following:** + +1. Let the object print itself if it knows how. +2. Take the best fitting method defined in the printer. +3. As fall-back use the emptyPrinter method for the printer. + +Which Method is Responsible for Printing? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The whole printing process is started by calling ``.doprint(expr)`` on the printer +which you want to use. This method looks for an appropriate method which can +print the given expression in the given style that the printer defines. +While looking for the method, it follows these steps: + +1. **Let the object print itself if it knows how.** + + The printer looks for a specific method in every object. The name of that method + depends on the specific printer and is defined under ``Printer.printmethod``. + For example, StrPrinter calls ``_sympystr`` and LatexPrinter calls ``_latex``. + Look at the documentation of the printer that you want to use. + The name of the method is specified there. + + This was the original way of doing printing in sympy. Every class had + its own latex, mathml, str and repr methods, but it turned out that it + is hard to produce a high quality printer, if all the methods are spread + out that far. Therefore all printing code was combined into the different + printers, which works great for built-in SymPy objects, but not that + good for user defined classes where it is inconvenient to patch the + printers. + +2. **Take the best fitting method defined in the printer.** + + The printer loops through expr classes (class + its bases), and tries + to dispatch the work to ``_print_`` + + e.g., suppose we have the following class hierarchy:: + + Basic + | + Atom + | + Number + | + Rational + + then, for ``expr=Rational(...)``, the Printer will try + to call printer methods in the order as shown in the figure below:: + + p._print(expr) + | + |-- p._print_Rational(expr) + | + |-- p._print_Number(expr) + | + |-- p._print_Atom(expr) + | + `-- p._print_Basic(expr) + + if ``._print_Rational`` method exists in the printer, then it is called, + and the result is returned back. Otherwise, the printer tries to call + ``._print_Number`` and so on. + +3. **As a fall-back use the emptyPrinter method for the printer.** + + As fall-back ``self.emptyPrinter`` will be called with the expression. If + not defined in the Printer subclass this will be the same as ``str(expr)``. + +.. _printer_example: + +Example of Custom Printer +^^^^^^^^^^^^^^^^^^^^^^^^^ + +In the example below, we have a printer which prints the derivative of a function +in a shorter form. + +.. code-block:: python + + from sympy.core.symbol import Symbol + from sympy.printing.latex import LatexPrinter, print_latex + from sympy.core.function import UndefinedFunction, Function + + + class MyLatexPrinter(LatexPrinter): + \"\"\"Print derivative of a function of symbols in a shorter form. + \"\"\" + def _print_Derivative(self, expr): + function, *vars = expr.args + if not isinstance(type(function), UndefinedFunction) or \\ + not all(isinstance(i, Symbol) for i in vars): + return super()._print_Derivative(expr) + + # If you want the printer to work correctly for nested + # expressions then use self._print() instead of str() or latex(). + # See the example of nested modulo below in the custom printing + # method section. + return "{}_{{{}}}".format( + self._print(Symbol(function.func.__name__)), + ''.join(self._print(i) for i in vars)) + + + def print_my_latex(expr): + \"\"\" Most of the printers define their own wrappers for print(). + These wrappers usually take printer settings. Our printer does not have + any settings. + \"\"\" + print(MyLatexPrinter().doprint(expr)) + + + y = Symbol("y") + x = Symbol("x") + f = Function("f") + expr = f(x, y).diff(x, y) + + # Print the expression using the normal latex printer and our custom + # printer. + print_latex(expr) + print_my_latex(expr) + +The output of the code above is:: + + \\frac{\\partial^{2}}{\\partial x\\partial y} f{\\left(x,y \\right)} + f_{xy} + +.. _printer_method_example: + +Example of Custom Printing Method +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In the example below, the latex printing of the modulo operator is modified. +This is done by overriding the method ``_latex`` of ``Mod``. + +>>> from sympy import Symbol, Mod, Integer, print_latex + +>>> # Always use printer._print() +>>> class ModOp(Mod): +... def _latex(self, printer): +... a, b = [printer._print(i) for i in self.args] +... return r"\\operatorname{Mod}{\\left(%s, %s\\right)}" % (a, b) + +Comparing the output of our custom operator to the builtin one: + +>>> x = Symbol('x') +>>> m = Symbol('m') +>>> print_latex(Mod(x, m)) +x \\bmod m +>>> print_latex(ModOp(x, m)) +\\operatorname{Mod}{\\left(x, m\\right)} + +Common mistakes +~~~~~~~~~~~~~~~ +It's important to always use ``self._print(obj)`` to print subcomponents of +an expression when customizing a printer. Mistakes include: + +1. Using ``self.doprint(obj)`` instead: + + >>> # This example does not work properly, as only the outermost call may use + >>> # doprint. + >>> class ModOpModeWrong(Mod): + ... def _latex(self, printer): + ... a, b = [printer.doprint(i) for i in self.args] + ... return r"\\operatorname{Mod}{\\left(%s, %s\\right)}" % (a, b) + + This fails when the ``mode`` argument is passed to the printer: + + >>> print_latex(ModOp(x, m), mode='inline') # ok + $\\operatorname{Mod}{\\left(x, m\\right)}$ + >>> print_latex(ModOpModeWrong(x, m), mode='inline') # bad + $\\operatorname{Mod}{\\left($x$, $m$\\right)}$ + +2. Using ``str(obj)`` instead: + + >>> class ModOpNestedWrong(Mod): + ... def _latex(self, printer): + ... a, b = [str(i) for i in self.args] + ... return r"\\operatorname{Mod}{\\left(%s, %s\\right)}" % (a, b) + + This fails on nested objects: + + >>> # Nested modulo. + >>> print_latex(ModOp(ModOp(x, m), Integer(7))) # ok + \\operatorname{Mod}{\\left(\\operatorname{Mod}{\\left(x, m\\right)}, 7\\right)} + >>> print_latex(ModOpNestedWrong(ModOpNestedWrong(x, m), Integer(7))) # bad + \\operatorname{Mod}{\\left(ModOpNestedWrong(x, m), 7\\right)} + +3. Using ``LatexPrinter()._print(obj)`` instead. + + >>> from sympy.printing.latex import LatexPrinter + >>> class ModOpSettingsWrong(Mod): + ... def _latex(self, printer): + ... a, b = [LatexPrinter()._print(i) for i in self.args] + ... return r"\\operatorname{Mod}{\\left(%s, %s\\right)}" % (a, b) + + This causes all the settings to be discarded in the subobjects. As an + example, the ``full_prec`` setting which shows floats to full precision is + ignored: + + >>> from sympy import Float + >>> print_latex(ModOp(Float(1) * x, m), full_prec=True) # ok + \\operatorname{Mod}{\\left(1.00000000000000 x, m\\right)} + >>> print_latex(ModOpSettingsWrong(Float(1) * x, m), full_prec=True) # bad + \\operatorname{Mod}{\\left(1.0 x, m\\right)} + +""" + +from __future__ import annotations +import sys +from typing import Any, Type +import inspect +from contextlib import contextmanager +from functools import cmp_to_key, update_wrapper + +from sympy.core.add import Add +from sympy.core.basic import Basic + +from sympy.core.function import AppliedUndef, UndefinedFunction, Function + + + +@contextmanager +def printer_context(printer, **kwargs): + original = printer._context.copy() + try: + printer._context.update(kwargs) + yield + finally: + printer._context = original + + +class Printer: + """ Generic printer + + Its job is to provide infrastructure for implementing new printers easily. + + If you want to define your custom Printer or your custom printing method + for your custom class then see the example above: printer_example_ . + """ + + _global_settings: dict[str, Any] = {} + + _default_settings: dict[str, Any] = {} + + # must be initialized to pass tests and cannot be set to '| None' to pass mypy + printmethod = None # type: str + + @classmethod + def _get_initial_settings(cls): + settings = cls._default_settings.copy() + for key, val in cls._global_settings.items(): + if key in cls._default_settings: + settings[key] = val + return settings + + def __init__(self, settings=None): + self._str = str + + self._settings = self._get_initial_settings() + self._context = {} # mutable during printing + + if settings is not None: + self._settings.update(settings) + + if len(self._settings) > len(self._default_settings): + for key in self._settings: + if key not in self._default_settings: + raise TypeError("Unknown setting '%s'." % key) + + # _print_level is the number of times self._print() was recursively + # called. See StrPrinter._print_Float() for an example of usage + self._print_level = 0 + + @classmethod + def set_global_settings(cls, **settings): + """Set system-wide printing settings. """ + for key, val in settings.items(): + if val is not None: + cls._global_settings[key] = val + + @property + def order(self): + if 'order' in self._settings: + return self._settings['order'] + else: + raise AttributeError("No order defined.") + + def doprint(self, expr): + """Returns printer's representation for expr (as a string)""" + return self._str(self._print(expr)) + + def _print(self, expr, **kwargs) -> str: + """Internal dispatcher + + Tries the following concepts to print an expression: + 1. Let the object print itself if it knows how. + 2. Take the best fitting method defined in the printer. + 3. As fall-back use the emptyPrinter method for the printer. + """ + self._print_level += 1 + try: + # If the printer defines a name for a printing method + # (Printer.printmethod) and the object knows for itself how it + # should be printed, use that method. + if self.printmethod and hasattr(expr, self.printmethod): + if not (isinstance(expr, type) and issubclass(expr, Basic)): + return getattr(expr, self.printmethod)(self, **kwargs) + + # See if the class of expr is known, or if one of its super + # classes is known, and use that print function + # Exception: ignore the subclasses of Undefined, so that, e.g., + # Function('gamma') does not get dispatched to _print_gamma + classes = type(expr).__mro__ + if AppliedUndef in classes: + classes = classes[classes.index(AppliedUndef):] + if UndefinedFunction in classes: + classes = classes[classes.index(UndefinedFunction):] + # Another exception: if someone subclasses a known function, e.g., + # gamma, and changes the name, then ignore _print_gamma + if Function in classes: + i = classes.index(Function) + classes = tuple(c for c in classes[:i] if \ + c.__name__ == classes[0].__name__ or \ + c.__name__.endswith("Base")) + classes[i:] + for cls in classes: + printmethodname = '_print_' + cls.__name__ + printmethod = getattr(self, printmethodname, None) + if printmethod is not None: + return printmethod(expr, **kwargs) + # Unknown object, fall back to the emptyPrinter. + return self.emptyPrinter(expr) + finally: + self._print_level -= 1 + + def emptyPrinter(self, expr): + return str(expr) + + def _as_ordered_terms(self, expr, order=None): + """A compatibility function for ordering terms in Add. """ + order = order or self.order + + if order == 'old': + return sorted(Add.make_args(expr), key=cmp_to_key(self._compare_pretty)) + elif order == 'none': + return list(expr.args) + else: + return expr.as_ordered_terms(order=order) + + def _compare_pretty(self, a, b): + """return -1, 0, 1 if a is canonically less, equal or + greater than b. This is used when 'order=old' is selected + for printing. This puts Order last, orders Rationals + according to value, puts terms in order wrt the power of + the last power appearing in a term. Ties are broken using + Basic.compare. + """ + from sympy.core.numbers import Rational + from sympy.core.symbol import Wild + from sympy.series.order import Order + if isinstance(a, Order) and not isinstance(b, Order): + return 1 + if not isinstance(a, Order) and isinstance(b, Order): + return -1 + + if isinstance(a, Rational) and isinstance(b, Rational): + l = a.p * b.q + r = b.p * a.q + return (l > r) - (l < r) + else: + p1, p2, p3 = Wild("p1"), Wild("p2"), Wild("p3") + r_a = a.match(p1 * p2**p3) + if r_a and p3 in r_a: + a3 = r_a[p3] + r_b = b.match(p1 * p2**p3) + if r_b and p3 in r_b: + b3 = r_b[p3] + c = Basic.compare(a3, b3) + if c != 0: + return c + + # break ties + return Basic.compare(a, b) + + +class _PrintFunction: + """ + Function wrapper to replace ``**settings`` in the signature with printer defaults + """ + def __init__(self, f, print_cls: Type[Printer]): + # find all the non-setting arguments + params = list(inspect.signature(f).parameters.values()) + assert params.pop(-1).kind == inspect.Parameter.VAR_KEYWORD + self.__other_params = params + + self.__print_cls = print_cls + update_wrapper(self, f) + + def __reduce__(self): + # Since this is used as a decorator, it replaces the original function. + # The default pickling will try to pickle self.__wrapped__ and fail + # because the wrapped function can't be retrieved by name. + return self.__wrapped__.__qualname__ + + def __call__(self, *args, **kwargs): + return self.__wrapped__(*args, **kwargs) + + @property + def __signature__(self) -> inspect.Signature: + settings = self.__print_cls._get_initial_settings() + return inspect.Signature( + parameters=self.__other_params + [ + inspect.Parameter(k, inspect.Parameter.KEYWORD_ONLY, default=v) + for k, v in settings.items() + ], + return_annotation=self.__wrapped__.__annotations__.get('return', inspect.Signature.empty) # type:ignore + ) + + +def print_function(print_cls): + """ A decorator to replace kwargs with the printer settings in __signature__ """ + def decorator(f): + if sys.version_info < (3, 9): + # We have to create a subclass so that `help` actually shows the docstring in older Python versions. + # IPython and Sphinx do not need this, only a raw Python console. + cls = type(f'{f.__qualname__}_PrintFunction', (_PrintFunction,), {"__doc__": f.__doc__}) + else: + cls = _PrintFunction + return cls(f, print_cls) + return decorator diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/pycode.py b/.venv/lib/python3.13/site-packages/sympy/printing/pycode.py new file mode 100644 index 0000000000000000000000000000000000000000..09bdc6788775d409c06bdaae0a43c54544894602 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/pycode.py @@ -0,0 +1,852 @@ +""" +Python code printers + +This module contains Python code printers for plain Python as well as NumPy & SciPy enabled code. +""" +from collections import defaultdict +from itertools import chain +from sympy.core import S +from sympy.core.mod import Mod +from .precedence import precedence +from .codeprinter import CodePrinter + +_kw = { + 'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif', + 'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import', 'in', + 'is', 'lambda', 'not', 'or', 'pass', 'raise', 'return', 'try', 'while', + 'with', 'yield', 'None', 'False', 'nonlocal', 'True' +} + +_known_functions = { + 'Abs': 'abs', + 'Min': 'min', + 'Max': 'max', +} +_known_functions_math = { + 'acos': 'acos', + 'acosh': 'acosh', + 'asin': 'asin', + 'asinh': 'asinh', + 'atan': 'atan', + 'atan2': 'atan2', + 'atanh': 'atanh', + 'ceiling': 'ceil', + 'cos': 'cos', + 'cosh': 'cosh', + 'erf': 'erf', + 'erfc': 'erfc', + 'exp': 'exp', + 'expm1': 'expm1', + 'factorial': 'factorial', + 'floor': 'floor', + 'gamma': 'gamma', + 'hypot': 'hypot', + 'isinf': 'isinf', + 'isnan': 'isnan', + 'loggamma': 'lgamma', + 'log': 'log', + 'ln': 'log', + 'log10': 'log10', + 'log1p': 'log1p', + 'log2': 'log2', + 'sin': 'sin', + 'sinh': 'sinh', + 'Sqrt': 'sqrt', + 'tan': 'tan', + 'tanh': 'tanh' +} # Not used from ``math``: [copysign isclose isfinite isinf ldexp frexp pow modf +# radians trunc fmod fsum gcd degrees fabs] +_known_constants_math = { + 'Exp1': 'e', + 'Pi': 'pi', + 'E': 'e', + 'Infinity': 'inf', + 'NaN': 'nan', + 'ComplexInfinity': 'nan' +} + +def _print_known_func(self, expr): + known = self.known_functions[expr.__class__.__name__] + return '{name}({args})'.format(name=self._module_format(known), + args=', '.join((self._print(arg) for arg in expr.args))) + + +def _print_known_const(self, expr): + known = self.known_constants[expr.__class__.__name__] + return self._module_format(known) + + +class AbstractPythonCodePrinter(CodePrinter): + printmethod = "_pythoncode" + language = "Python" + reserved_words = _kw + modules = None # initialized to a set in __init__ + tab = ' ' + _kf = dict(chain( + _known_functions.items(), + [(k, 'math.' + v) for k, v in _known_functions_math.items()] + )) + _kc = {k: 'math.'+v for k, v in _known_constants_math.items()} + _operators = {'and': 'and', 'or': 'or', 'not': 'not'} + _default_settings = dict( + CodePrinter._default_settings, + user_functions={}, + precision=17, + inline=True, + fully_qualified_modules=True, + contract=False, + standard='python3', + ) + + def __init__(self, settings=None): + super().__init__(settings) + + # Python standard handler + std = self._settings['standard'] + if std is None: + import sys + std = 'python{}'.format(sys.version_info.major) + if std != 'python3': + raise ValueError('Only Python 3 is supported.') + self.standard = std + + self.module_imports = defaultdict(set) + + # Known functions and constants handler + self.known_functions = dict(self._kf, **(settings or {}).get( + 'user_functions', {})) + self.known_constants = dict(self._kc, **(settings or {}).get( + 'user_constants', {})) + + def _declare_number_const(self, name, value): + return "%s = %s" % (name, value) + + def _module_format(self, fqn, register=True): + parts = fqn.split('.') + if register and len(parts) > 1: + self.module_imports['.'.join(parts[:-1])].add(parts[-1]) + + if self._settings['fully_qualified_modules']: + return fqn + else: + return fqn.split('(')[0].split('[')[0].split('.')[-1] + + def _format_code(self, lines): + return lines + + def _get_statement(self, codestring): + return "{}".format(codestring) + + def _get_comment(self, text): + return " # {}".format(text) + + def _expand_fold_binary_op(self, op, args): + """ + This method expands a fold on binary operations. + + ``functools.reduce`` is an example of a folded operation. + + For example, the expression + + `A + B + C + D` + + is folded into + + `((A + B) + C) + D` + """ + if len(args) == 1: + return self._print(args[0]) + else: + return "%s(%s, %s)" % ( + self._module_format(op), + self._expand_fold_binary_op(op, args[:-1]), + self._print(args[-1]), + ) + + def _expand_reduce_binary_op(self, op, args): + """ + This method expands a reduction on binary operations. + + Notice: this is NOT the same as ``functools.reduce``. + + For example, the expression + + `A + B + C + D` + + is reduced into: + + `(A + B) + (C + D)` + """ + if len(args) == 1: + return self._print(args[0]) + else: + N = len(args) + Nhalf = N // 2 + return "%s(%s, %s)" % ( + self._module_format(op), + self._expand_reduce_binary_op(args[:Nhalf]), + self._expand_reduce_binary_op(args[Nhalf:]), + ) + + def _print_NaN(self, expr): + return "float('nan')" + + def _print_Infinity(self, expr): + return "float('inf')" + + def _print_NegativeInfinity(self, expr): + return "float('-inf')" + + def _print_ComplexInfinity(self, expr): + return self._print_NaN(expr) + + def _print_Mod(self, expr): + PREC = precedence(expr) + return ('{} % {}'.format(*(self.parenthesize(x, PREC) for x in expr.args))) + + def _print_Piecewise(self, expr): + result = [] + i = 0 + for arg in expr.args: + e = arg.expr + c = arg.cond + if i == 0: + result.append('(') + result.append('(') + result.append(self._print(e)) + result.append(')') + result.append(' if ') + result.append(self._print(c)) + result.append(' else ') + i += 1 + result = result[:-1] + if result[-1] == 'True': + result = result[:-2] + result.append(')') + else: + result.append(' else None)') + return ''.join(result) + + def _print_Relational(self, expr): + "Relational printer for Equality and Unequality" + op = { + '==' :'equal', + '!=' :'not_equal', + '<' :'less', + '<=' :'less_equal', + '>' :'greater', + '>=' :'greater_equal', + } + if expr.rel_op in op: + lhs = self._print(expr.lhs) + rhs = self._print(expr.rhs) + return '({lhs} {op} {rhs})'.format(op=expr.rel_op, lhs=lhs, rhs=rhs) + return super()._print_Relational(expr) + + def _print_ITE(self, expr): + from sympy.functions.elementary.piecewise import Piecewise + return self._print(expr.rewrite(Piecewise)) + + def _print_Sum(self, expr): + loops = ( + 'for {i} in range({a}, {b}+1)'.format( + i=self._print(i), + a=self._print(a), + b=self._print(b)) + for i, a, b in expr.limits[::-1]) + return '(builtins.sum({function} {loops}))'.format( + function=self._print(expr.function), + loops=' '.join(loops)) + + def _print_ImaginaryUnit(self, expr): + return '1j' + + def _print_KroneckerDelta(self, expr): + a, b = expr.args + + return '(1 if {a} == {b} else 0)'.format( + a = self._print(a), + b = self._print(b) + ) + + def _print_MatrixBase(self, expr): + name = expr.__class__.__name__ + func = self.known_functions.get(name, name) + return "%s(%s)" % (func, self._print(expr.tolist())) + + _print_SparseRepMatrix = \ + _print_MutableSparseMatrix = \ + _print_ImmutableSparseMatrix = \ + _print_Matrix = \ + _print_DenseMatrix = \ + _print_MutableDenseMatrix = \ + _print_ImmutableMatrix = \ + _print_ImmutableDenseMatrix = \ + lambda self, expr: self._print_MatrixBase(expr) + + def _indent_codestring(self, codestring): + return '\n'.join([self.tab + line for line in codestring.split('\n')]) + + def _print_FunctionDefinition(self, fd): + body = '\n'.join((self._print(arg) for arg in fd.body)) + return "def {name}({parameters}):\n{body}".format( + name=self._print(fd.name), + parameters=', '.join([self._print(var.symbol) for var in fd.parameters]), + body=self._indent_codestring(body) + ) + + def _print_While(self, whl): + body = '\n'.join((self._print(arg) for arg in whl.body)) + return "while {cond}:\n{body}".format( + cond=self._print(whl.condition), + body=self._indent_codestring(body) + ) + + def _print_Declaration(self, decl): + return '%s = %s' % ( + self._print(decl.variable.symbol), + self._print(decl.variable.value) + ) + + def _print_BreakToken(self, bt): + return 'break' + + def _print_Return(self, ret): + arg, = ret.args + return 'return %s' % self._print(arg) + + def _print_Raise(self, rs): + arg, = rs.args + return 'raise %s' % self._print(arg) + + def _print_RuntimeError_(self, re): + message, = re.args + return "RuntimeError(%s)" % self._print(message) + + def _print_Print(self, prnt): + print_args = ', '.join((self._print(arg) for arg in prnt.print_args)) + from sympy.codegen.ast import none + if prnt.format_string != none: + print_args = '{} % ({}), end=""'.format( + self._print(prnt.format_string), + print_args + ) + if prnt.file != None: # Must be '!= None', cannot be 'is not None' + print_args += ', file=%s' % self._print(prnt.file) + return 'print(%s)' % print_args + + def _print_Stream(self, strm): + if str(strm.name) == 'stdout': + return self._module_format('sys.stdout') + elif str(strm.name) == 'stderr': + return self._module_format('sys.stderr') + else: + return self._print(strm.name) + + def _print_NoneToken(self, arg): + return 'None' + + def _hprint_Pow(self, expr, rational=False, sqrt='math.sqrt'): + """Printing helper function for ``Pow`` + + Notes + ===== + + This preprocesses the ``sqrt`` as math formatter and prints division + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.printing.pycode import PythonCodePrinter + >>> from sympy.abc import x + + Python code printer automatically looks up ``math.sqrt``. + + >>> printer = PythonCodePrinter() + >>> printer._hprint_Pow(sqrt(x), rational=True) + 'x**(1/2)' + >>> printer._hprint_Pow(sqrt(x), rational=False) + 'math.sqrt(x)' + >>> printer._hprint_Pow(1/sqrt(x), rational=True) + 'x**(-1/2)' + >>> printer._hprint_Pow(1/sqrt(x), rational=False) + '1/math.sqrt(x)' + >>> printer._hprint_Pow(1/x, rational=False) + '1/x' + >>> printer._hprint_Pow(1/x, rational=True) + 'x**(-1)' + + Using sqrt from numpy or mpmath + + >>> printer._hprint_Pow(sqrt(x), sqrt='numpy.sqrt') + 'numpy.sqrt(x)' + >>> printer._hprint_Pow(sqrt(x), sqrt='mpmath.sqrt') + 'mpmath.sqrt(x)' + + See Also + ======== + + sympy.printing.str.StrPrinter._print_Pow + """ + PREC = precedence(expr) + + if expr.exp == S.Half and not rational: + func = self._module_format(sqrt) + arg = self._print(expr.base) + return '{func}({arg})'.format(func=func, arg=arg) + + if expr.is_commutative and not rational: + if -expr.exp is S.Half: + func = self._module_format(sqrt) + num = self._print(S.One) + arg = self._print(expr.base) + return f"{num}/{func}({arg})" + if expr.exp is S.NegativeOne: + num = self._print(S.One) + arg = self.parenthesize(expr.base, PREC, strict=False) + return f"{num}/{arg}" + + + base_str = self.parenthesize(expr.base, PREC, strict=False) + exp_str = self.parenthesize(expr.exp, PREC, strict=False) + return "{}**{}".format(base_str, exp_str) + + +class ArrayPrinter: + + def _arrayify(self, indexed): + from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array + try: + return convert_indexed_to_array(indexed) + except Exception: + return indexed + + def _get_einsum_string(self, subranks, contraction_indices): + letters = self._get_letter_generator_for_einsum() + contraction_string = "" + counter = 0 + d = {j: min(i) for i in contraction_indices for j in i} + indices = [] + for rank_arg in subranks: + lindices = [] + for i in range(rank_arg): + if counter in d: + lindices.append(d[counter]) + else: + lindices.append(counter) + counter += 1 + indices.append(lindices) + mapping = {} + letters_free = [] + letters_dum = [] + for i in indices: + for j in i: + if j not in mapping: + l = next(letters) + mapping[j] = l + else: + l = mapping[j] + contraction_string += l + if j in d: + if l not in letters_dum: + letters_dum.append(l) + else: + letters_free.append(l) + contraction_string += "," + contraction_string = contraction_string[:-1] + return contraction_string, letters_free, letters_dum + + def _get_letter_generator_for_einsum(self): + for i in range(97, 123): + yield chr(i) + for i in range(65, 91): + yield chr(i) + raise ValueError("out of letters") + + def _print_ArrayTensorProduct(self, expr): + letters = self._get_letter_generator_for_einsum() + contraction_string = ",".join(["".join([next(letters) for j in range(i)]) for i in expr.subranks]) + return '%s("%s", %s)' % ( + self._module_format(self._module + "." + self._einsum), + contraction_string, + ", ".join([self._print(arg) for arg in expr.args]) + ) + + def _print_ArrayContraction(self, expr): + from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct + base = expr.expr + contraction_indices = expr.contraction_indices + + if isinstance(base, ArrayTensorProduct): + elems = ",".join(["%s" % (self._print(arg)) for arg in base.args]) + ranks = base.subranks + else: + elems = self._print(base) + ranks = [len(base.shape)] + + contraction_string, letters_free, letters_dum = self._get_einsum_string(ranks, contraction_indices) + + if not contraction_indices: + return self._print(base) + if isinstance(base, ArrayTensorProduct): + elems = ",".join(["%s" % (self._print(arg)) for arg in base.args]) + else: + elems = self._print(base) + return "%s(\"%s\", %s)" % ( + self._module_format(self._module + "." + self._einsum), + "{}->{}".format(contraction_string, "".join(sorted(letters_free))), + elems, + ) + + def _print_ArrayDiagonal(self, expr): + from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct + diagonal_indices = list(expr.diagonal_indices) + if isinstance(expr.expr, ArrayTensorProduct): + subranks = expr.expr.subranks + elems = expr.expr.args + else: + subranks = expr.subranks + elems = [expr.expr] + diagonal_string, letters_free, letters_dum = self._get_einsum_string(subranks, diagonal_indices) + elems = [self._print(i) for i in elems] + return '%s("%s", %s)' % ( + self._module_format(self._module + "." + self._einsum), + "{}->{}".format(diagonal_string, "".join(letters_free+letters_dum)), + ", ".join(elems) + ) + + def _print_PermuteDims(self, expr): + return "%s(%s, %s)" % ( + self._module_format(self._module + "." + self._transpose), + self._print(expr.expr), + self._print(expr.permutation.array_form), + ) + + def _print_ArrayAdd(self, expr): + return self._expand_fold_binary_op(self._module + "." + self._add, expr.args) + + def _print_OneArray(self, expr): + return "%s((%s,))" % ( + self._module_format(self._module+ "." + self._ones), + ','.join(map(self._print,expr.args)) + ) + + def _print_ZeroArray(self, expr): + return "%s((%s,))" % ( + self._module_format(self._module+ "." + self._zeros), + ','.join(map(self._print,expr.args)) + ) + + def _print_Assignment(self, expr): + #XXX: maybe this needs to happen at a higher level e.g. at _print or + #doprint? + lhs = self._print(self._arrayify(expr.lhs)) + rhs = self._print(self._arrayify(expr.rhs)) + return "%s = %s" % ( lhs, rhs ) + + def _print_IndexedBase(self, expr): + return self._print_ArraySymbol(expr) + + +class PythonCodePrinter(AbstractPythonCodePrinter): + + def _print_sign(self, e): + return '(0.0 if {e} == 0 else {f}(1, {e}))'.format( + f=self._module_format('math.copysign'), e=self._print(e.args[0])) + + def _print_Not(self, expr): + PREC = precedence(expr) + return self._operators['not'] + ' ' + self.parenthesize(expr.args[0], PREC) + + def _print_IndexedBase(self, expr): + return expr.name + + def _print_Indexed(self, expr): + base = expr.args[0] + index = expr.args[1:] + return "{}[{}]".format(str(base), ", ".join([self._print(ind) for ind in index])) + + def _print_Pow(self, expr, rational=False): + return self._hprint_Pow(expr, rational=rational) + + def _print_Rational(self, expr): + return '{}/{}'.format(expr.p, expr.q) + + def _print_Half(self, expr): + return self._print_Rational(expr) + + def _print_frac(self, expr): + return self._print_Mod(Mod(expr.args[0], 1)) + + def _print_Symbol(self, expr): + + name = super()._print_Symbol(expr) + + if name in self.reserved_words: + if self._settings['error_on_reserved']: + msg = ('This expression includes the symbol "{}" which is a ' + 'reserved keyword in this language.') + raise ValueError(msg.format(name)) + return name + self._settings['reserved_word_suffix'] + elif '{' in name: # Remove curly braces from subscripted variables + return name.replace('{', '').replace('}', '') + else: + return name + + _print_lowergamma = CodePrinter._print_not_supported + _print_uppergamma = CodePrinter._print_not_supported + _print_fresnelc = CodePrinter._print_not_supported + _print_fresnels = CodePrinter._print_not_supported + + +for k in PythonCodePrinter._kf: + setattr(PythonCodePrinter, '_print_%s' % k, _print_known_func) + +for k in _known_constants_math: + setattr(PythonCodePrinter, '_print_%s' % k, _print_known_const) + + +def pycode(expr, **settings): + """ Converts an expr to a string of Python code + + Parameters + ========== + + expr : Expr + A SymPy expression. + fully_qualified_modules : bool + Whether or not to write out full module names of functions + (``math.sin`` vs. ``sin``). default: ``True``. + standard : str or None, optional + Only 'python3' (default) is supported. + This parameter may be removed in the future. + + Examples + ======== + + >>> from sympy import pycode, tan, Symbol + >>> pycode(tan(Symbol('x')) + 1) + 'math.tan(x) + 1' + + """ + return PythonCodePrinter(settings).doprint(expr) + + +from itertools import chain +from sympy.printing.pycode import PythonCodePrinter + +_known_functions_cmath = { + 'exp': 'exp', + 'sqrt': 'sqrt', + 'log': 'log', + 'cos': 'cos', + 'sin': 'sin', + 'tan': 'tan', + 'acos': 'acos', + 'asin': 'asin', + 'atan': 'atan', + 'cosh': 'cosh', + 'sinh': 'sinh', + 'tanh': 'tanh', + 'acosh': 'acosh', + 'asinh': 'asinh', + 'atanh': 'atanh', +} + +_known_constants_cmath = { + 'Pi': 'pi', + 'E': 'e', + 'Infinity': 'inf', + 'NegativeInfinity': '-inf', +} + +class CmathPrinter(PythonCodePrinter): + """ Printer for Python's cmath module """ + printmethod = "_cmathcode" + language = "Python with cmath" + + _kf = dict(chain( + _known_functions_cmath.items() + )) + + _kc = {k: 'cmath.' + v for k, v in _known_constants_cmath.items()} + + def _print_Pow(self, expr, rational=False): + return self._hprint_Pow(expr, rational=rational, sqrt='cmath.sqrt') + + def _print_Float(self, e): + return '{func}({val})'.format(func=self._module_format('cmath.mpf'), val=self._print(e)) + + def _print_known_func(self, expr): + func_name = expr.func.__name__ + if func_name in self._kf: + return f"cmath.{self._kf[func_name]}({', '.join(map(self._print, expr.args))})" + return super()._print_Function(expr) + + def _print_known_const(self, expr): + return self._kc[expr.__class__.__name__] + + def _print_re(self, expr): + """Prints `re(z)` as `z.real`""" + return f"({self._print(expr.args[0])}).real" + + def _print_im(self, expr): + """Prints `im(z)` as `z.imag`""" + return f"({self._print(expr.args[0])}).imag" + + +for k in CmathPrinter._kf: + setattr(CmathPrinter, '_print_%s' % k, CmathPrinter._print_known_func) + +for k in _known_constants_cmath: + setattr(CmathPrinter, '_print_%s' % k, CmathPrinter._print_known_const) + + +_not_in_mpmath = 'log1p log2'.split() +_in_mpmath = [(k, v) for k, v in _known_functions_math.items() if k not in _not_in_mpmath] +_known_functions_mpmath = dict(_in_mpmath, **{ + 'beta': 'beta', + 'frac': 'frac', + 'fresnelc': 'fresnelc', + 'fresnels': 'fresnels', + 'sign': 'sign', + 'loggamma': 'loggamma', + 'hyper': 'hyper', + 'meijerg': 'meijerg', + 'besselj': 'besselj', + 'bessely': 'bessely', + 'besseli': 'besseli', + 'besselk': 'besselk', +}) +_known_constants_mpmath = { + 'Exp1': 'e', + 'Pi': 'pi', + 'GoldenRatio': 'phi', + 'EulerGamma': 'euler', + 'Catalan': 'catalan', + 'NaN': 'nan', + 'Infinity': 'inf', + 'NegativeInfinity': 'ninf' +} + + +def _unpack_integral_limits(integral_expr): + """ helper function for _print_Integral that + - accepts an Integral expression + - returns a tuple of + - a list variables of integration + - a list of tuples of the upper and lower limits of integration + """ + integration_vars = [] + limits = [] + for integration_range in integral_expr.limits: + if len(integration_range) == 3: + integration_var, lower_limit, upper_limit = integration_range + else: + raise NotImplementedError("Only definite integrals are supported") + integration_vars.append(integration_var) + limits.append((lower_limit, upper_limit)) + return integration_vars, limits + + +class MpmathPrinter(PythonCodePrinter): + """ + Lambda printer for mpmath which maintains precision for floats + """ + printmethod = "_mpmathcode" + + language = "Python with mpmath" + + _kf = dict(chain( + _known_functions.items(), + [(k, 'mpmath.' + v) for k, v in _known_functions_mpmath.items()] + )) + _kc = {k: 'mpmath.'+v for k, v in _known_constants_mpmath.items()} + + def _print_Float(self, e): + # XXX: This does not handle setting mpmath.mp.dps. It is assumed that + # the caller of the lambdified function will have set it to sufficient + # precision to match the Floats in the expression. + + # Remove 'mpz' if gmpy is installed. + args = str(tuple(map(int, e._mpf_))) + return '{func}({args})'.format(func=self._module_format('mpmath.mpf'), args=args) + + + def _print_Rational(self, e): + return "{func}({p})/{func}({q})".format( + func=self._module_format('mpmath.mpf'), + q=self._print(e.q), + p=self._print(e.p) + ) + + def _print_Half(self, e): + return self._print_Rational(e) + + def _print_uppergamma(self, e): + return "{}({}, {}, {})".format( + self._module_format('mpmath.gammainc'), + self._print(e.args[0]), + self._print(e.args[1]), + self._module_format('mpmath.inf')) + + def _print_lowergamma(self, e): + return "{}({}, 0, {})".format( + self._module_format('mpmath.gammainc'), + self._print(e.args[0]), + self._print(e.args[1])) + + def _print_log2(self, e): + return '{0}({1})/{0}(2)'.format( + self._module_format('mpmath.log'), self._print(e.args[0])) + + def _print_log1p(self, e): + return '{}({})'.format( + self._module_format('mpmath.log1p'), self._print(e.args[0])) + + def _print_Pow(self, expr, rational=False): + return self._hprint_Pow(expr, rational=rational, sqrt='mpmath.sqrt') + + def _print_Integral(self, e): + integration_vars, limits = _unpack_integral_limits(e) + + return "{}(lambda {}: {}, {})".format( + self._module_format("mpmath.quad"), + ", ".join(map(self._print, integration_vars)), + self._print(e.args[0]), + ", ".join("(%s, %s)" % tuple(map(self._print, l)) for l in limits)) + + + def _print_Derivative_zeta(self, args, seq_orders): + arg, = args + deriv_order, = seq_orders + return '{}({}, derivative={})'.format( + self._module_format('mpmath.zeta'), + self._print(arg), deriv_order + ) + + +for k in MpmathPrinter._kf: + setattr(MpmathPrinter, '_print_%s' % k, _print_known_func) + +for k in _known_constants_mpmath: + setattr(MpmathPrinter, '_print_%s' % k, _print_known_const) + + +class SymPyPrinter(AbstractPythonCodePrinter): + + language = "Python with SymPy" + + _default_settings = dict( + AbstractPythonCodePrinter._default_settings, + strict=False # any class name will per definition be what we target in SymPyPrinter. + ) + + def _print_Function(self, expr): + mod = expr.func.__module__ or '' + return '%s(%s)' % (self._module_format(mod + ('.' if mod else '') + expr.func.__name__), + ', '.join((self._print(arg) for arg in expr.args))) + + def _print_Pow(self, expr, rational=False): + return self._hprint_Pow(expr, rational=rational, sqrt='sympy.sqrt') diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/python.py b/.venv/lib/python3.13/site-packages/sympy/printing/python.py new file mode 100644 index 0000000000000000000000000000000000000000..2f6862574d99db90f289de65144c7122ed2d731a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/python.py @@ -0,0 +1,92 @@ +import keyword as kw +import sympy +from .repr import ReprPrinter +from .str import StrPrinter + +# A list of classes that should be printed using StrPrinter +STRPRINT = ("Add", "Infinity", "Integer", "Mul", "NegativeInfinity", "Pow") + + +class PythonPrinter(ReprPrinter, StrPrinter): + """A printer which converts an expression into its Python interpretation.""" + + def __init__(self, settings=None): + super().__init__(settings) + self.symbols = [] + self.functions = [] + + # Create print methods for classes that should use StrPrinter instead + # of ReprPrinter. + for name in STRPRINT: + f_name = "_print_%s" % name + f = getattr(StrPrinter, f_name) + setattr(PythonPrinter, f_name, f) + + def _print_Function(self, expr): + func = expr.func.__name__ + if not hasattr(sympy, func) and func not in self.functions: + self.functions.append(func) + return StrPrinter._print_Function(self, expr) + + # procedure (!) for defining symbols which have be defined in print_python() + def _print_Symbol(self, expr): + symbol = self._str(expr) + if symbol not in self.symbols: + self.symbols.append(symbol) + return StrPrinter._print_Symbol(self, expr) + + def _print_module(self, expr): + raise ValueError('Modules in the expression are unacceptable') + + +def python(expr, **settings): + """Return Python interpretation of passed expression + (can be passed to the exec() function without any modifications)""" + + printer = PythonPrinter(settings) + exprp = printer.doprint(expr) + + result = '' + # Returning found symbols and functions + renamings = {} + for symbolname in printer.symbols: + # Remove curly braces from subscripted variables + if '{' in symbolname: + newsymbolname = symbolname.replace('{', '').replace('}', '') + renamings[sympy.Symbol(symbolname)] = newsymbolname + else: + newsymbolname = symbolname + + # Escape symbol names that are reserved Python keywords + if kw.iskeyword(newsymbolname): + while True: + newsymbolname += "_" + if (newsymbolname not in printer.symbols and + newsymbolname not in printer.functions): + renamings[sympy.Symbol( + symbolname)] = sympy.Symbol(newsymbolname) + break + result += newsymbolname + ' = Symbol(\'' + symbolname + '\')\n' + + for functionname in printer.functions: + newfunctionname = functionname + # Escape function names that are reserved Python keywords + if kw.iskeyword(newfunctionname): + while True: + newfunctionname += "_" + if (newfunctionname not in printer.symbols and + newfunctionname not in printer.functions): + renamings[sympy.Function( + functionname)] = sympy.Function(newfunctionname) + break + result += newfunctionname + ' = Function(\'' + functionname + '\')\n' + + if renamings: + exprp = expr.subs(renamings) + result += 'e = ' + printer._str(exprp) + return result + + +def print_python(expr, **settings): + """Print output of python() function""" + print(python(expr, **settings)) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/pytorch.py b/.venv/lib/python3.13/site-packages/sympy/printing/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..0e8ff01856fa1dce7f8e786065a32bb74d987254 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/pytorch.py @@ -0,0 +1,297 @@ + +from sympy.printing.pycode import AbstractPythonCodePrinter, ArrayPrinter +from sympy.matrices.expressions import MatrixExpr +from sympy.core.mul import Mul +from sympy.printing.precedence import PRECEDENCE +from sympy.external import import_module +from sympy.codegen.cfunctions import Sqrt +from sympy import S +from sympy import Integer + +import sympy + +torch = import_module('torch') + + +class TorchPrinter(ArrayPrinter, AbstractPythonCodePrinter): + + printmethod = "_torchcode" + + mapping = { + sympy.Abs: "torch.abs", + sympy.sign: "torch.sign", + + # XXX May raise error for ints. + sympy.ceiling: "torch.ceil", + sympy.floor: "torch.floor", + sympy.log: "torch.log", + sympy.exp: "torch.exp", + Sqrt: "torch.sqrt", + sympy.cos: "torch.cos", + sympy.acos: "torch.acos", + sympy.sin: "torch.sin", + sympy.asin: "torch.asin", + sympy.tan: "torch.tan", + sympy.atan: "torch.atan", + sympy.atan2: "torch.atan2", + # XXX Also may give NaN for complex results. + sympy.cosh: "torch.cosh", + sympy.acosh: "torch.acosh", + sympy.sinh: "torch.sinh", + sympy.asinh: "torch.asinh", + sympy.tanh: "torch.tanh", + sympy.atanh: "torch.atanh", + sympy.Pow: "torch.pow", + + sympy.re: "torch.real", + sympy.im: "torch.imag", + sympy.arg: "torch.angle", + + # XXX May raise error for ints and complexes + sympy.erf: "torch.erf", + sympy.loggamma: "torch.lgamma", + + sympy.Eq: "torch.eq", + sympy.Ne: "torch.ne", + sympy.StrictGreaterThan: "torch.gt", + sympy.StrictLessThan: "torch.lt", + sympy.LessThan: "torch.le", + sympy.GreaterThan: "torch.ge", + + sympy.And: "torch.logical_and", + sympy.Or: "torch.logical_or", + sympy.Not: "torch.logical_not", + sympy.Max: "torch.max", + sympy.Min: "torch.min", + + # Matrices + sympy.MatAdd: "torch.add", + sympy.HadamardProduct: "torch.mul", + sympy.Trace: "torch.trace", + + # XXX May raise error for integer matrices. + sympy.Determinant: "torch.det", + } + + _default_settings = dict( + AbstractPythonCodePrinter._default_settings, + torch_version=None, + requires_grad=False, + dtype="torch.float64", + ) + + def __init__(self, settings=None): + super().__init__(settings) + + version = self._settings['torch_version'] + self.requires_grad = self._settings['requires_grad'] + self.dtype = self._settings['dtype'] + if version is None and torch: + version = torch.__version__ + self.torch_version = version + + def _print_Function(self, expr): + + op = self.mapping.get(type(expr), None) + if op is None: + return super()._print_Basic(expr) + children = [self._print(arg) for arg in expr.args] + if len(children) == 1: + return "%s(%s)" % ( + self._module_format(op), + children[0] + ) + else: + return self._expand_fold_binary_op(op, children) + + # mirrors the tensorflow version + _print_Expr = _print_Function + _print_Application = _print_Function + _print_MatrixExpr = _print_Function + _print_Relational = _print_Function + _print_Not = _print_Function + _print_And = _print_Function + _print_Or = _print_Function + _print_HadamardProduct = _print_Function + _print_Trace = _print_Function + _print_Determinant = _print_Function + + def _print_Inverse(self, expr): + return '{}({})'.format(self._module_format("torch.linalg.inv"), + self._print(expr.args[0])) + + def _print_Transpose(self, expr): + if expr.arg.is_Matrix and expr.arg.shape[0] == expr.arg.shape[1]: + # For square matrices, we can use the .t() method + return "{}({}).t()".format("torch.transpose", self._print(expr.arg)) + else: + # For non-square matrices or more general cases + # transpose first and second dimensions (typical matrix transpose) + return "{}.permute({})".format( + self._print(expr.arg), + ", ".join([str(i) for i in range(len(expr.arg.shape))])[::-1] + ) + + def _print_PermuteDims(self, expr): + return "%s.permute(%s)" % ( + self._print(expr.expr), + ", ".join(str(i) for i in expr.permutation.array_form) + ) + + def _print_Derivative(self, expr): + # this version handles multi-variable and mixed partial derivatives. The tensorflow version does not. + variables = expr.variables + expr_arg = expr.expr + + # Handle multi-variable or repeated derivatives + if len(variables) > 1 or ( + len(variables) == 1 and not isinstance(variables[0], tuple) and variables.count(variables[0]) > 1): + result = self._print(expr_arg) + var_groups = {} + + # Group variables by base symbol + for var in variables: + if isinstance(var, tuple): + base_var, order = var + var_groups[base_var] = var_groups.get(base_var, 0) + order + else: + var_groups[var] = var_groups.get(var, 0) + 1 + + # Apply gradients in sequence + for var, order in var_groups.items(): + for _ in range(order): + result = "torch.autograd.grad({}, {}, create_graph=True)[0]".format(result, self._print(var)) + return result + + # Handle single variable case + if len(variables) == 1: + variable = variables[0] + if isinstance(variable, tuple) and len(variable) == 2: + base_var, order = variable + if not isinstance(order, Integer): raise NotImplementedError("Only integer orders are supported") + result = self._print(expr_arg) + for _ in range(order): + result = "torch.autograd.grad({}, {}, create_graph=True)[0]".format(result, self._print(base_var)) + return result + return "torch.autograd.grad({}, {})[0]".format(self._print(expr_arg), self._print(variable)) + + return self._print(expr_arg) # Empty variables case + + def _print_Piecewise(self, expr): + from sympy import Piecewise + e, cond = expr.args[0].args + if len(expr.args) == 1: + return '{}({}, {}, {})'.format( + self._module_format("torch.where"), + self._print(cond), + self._print(e), + 0) + + return '{}({}, {}, {})'.format( + self._module_format("torch.where"), + self._print(cond), + self._print(e), + self._print(Piecewise(*expr.args[1:]))) + + def _print_Pow(self, expr): + # XXX May raise error for + # int**float or int**complex or float**complex + base, exp = expr.args + if expr.exp == S.Half: + return "{}({})".format( + self._module_format("torch.sqrt"), self._print(base)) + return "{}({}, {})".format( + self._module_format("torch.pow"), + self._print(base), self._print(exp)) + + def _print_MatMul(self, expr): + # Separate matrix and scalar arguments + mat_args = [arg for arg in expr.args if isinstance(arg, MatrixExpr)] + args = [arg for arg in expr.args if arg not in mat_args] + # Handle scalar multipliers if present + if args: + return "%s*%s" % ( + self.parenthesize(Mul.fromiter(args), PRECEDENCE["Mul"]), + self._expand_fold_binary_op("torch.matmul", mat_args) + ) + else: + return self._expand_fold_binary_op("torch.matmul", mat_args) + + def _print_MatPow(self, expr): + return self._expand_fold_binary_op("torch.mm", [expr.base]*expr.exp) + + def _print_MatrixBase(self, expr): + data = "[" + ", ".join(["[" + ", ".join([self._print(j) for j in i]) + "]" for i in expr.tolist()]) + "]" + params = [str(data)] + params.append(f"dtype={self.dtype}") + if self.requires_grad: + params.append("requires_grad=True") + + return "{}({})".format( + self._module_format("torch.tensor"), + ", ".join(params) + ) + + def _print_isnan(self, expr): + return f'torch.isnan({self._print(expr.args[0])})' + + def _print_isinf(self, expr): + return f'torch.isinf({self._print(expr.args[0])})' + + def _print_Identity(self, expr): + if all(dim.is_Integer for dim in expr.shape): + return "{}({})".format( + self._module_format("torch.eye"), + self._print(expr.shape[0]) + ) + else: + # For symbolic dimensions, fall back to a more general approach + return "{}({}, {})".format( + self._module_format("torch.eye"), + self._print(expr.shape[0]), + self._print(expr.shape[1]) + ) + + def _print_ZeroMatrix(self, expr): + return "{}({})".format( + self._module_format("torch.zeros"), + self._print(expr.shape) + ) + + def _print_OneMatrix(self, expr): + return "{}({})".format( + self._module_format("torch.ones"), + self._print(expr.shape) + ) + + def _print_conjugate(self, expr): + return f"{self._module_format('torch.conj')}({self._print(expr.args[0])})" + + def _print_ImaginaryUnit(self, expr): + return "1j" # uses the Python built-in 1j notation for the imaginary unit + + def _print_Heaviside(self, expr): + args = [self._print(expr.args[0]), "0.5"] + if len(expr.args) > 1: + args[1] = self._print(expr.args[1]) + return f"{self._module_format('torch.heaviside')}({args[0]}, {args[1]})" + + def _print_gamma(self, expr): + return f"{self._module_format('torch.special.gamma')}({self._print(expr.args[0])})" + + def _print_polygamma(self, expr): + if expr.args[0] == S.Zero: + return f"{self._module_format('torch.special.digamma')}({self._print(expr.args[1])})" + else: + raise NotImplementedError("PyTorch only supports digamma (0th order polygamma)") + + _module = "torch" + _einsum = "einsum" + _add = "add" + _transpose = "t" + _ones = "ones" + _zeros = "zeros" + +def torch_code(expr, requires_grad=False, dtype="torch.float64", **settings): + printer = TorchPrinter(settings={'requires_grad': requires_grad, 'dtype': dtype}) + return printer.doprint(expr, **settings) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/repr.py b/.venv/lib/python3.13/site-packages/sympy/printing/repr.py new file mode 100644 index 0000000000000000000000000000000000000000..0a4b756abbab77c3eb0fd77ee1f0bd97382c36fb --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/repr.py @@ -0,0 +1,339 @@ +""" +A Printer for generating executable code. + +The most important function here is srepr that returns a string so that the +relation eval(srepr(expr))=expr holds in an appropriate environment. +""" + +from __future__ import annotations +from typing import Any + +from sympy.core.function import AppliedUndef +from sympy.core.mul import Mul +from mpmath.libmp import repr_dps, to_str as mlib_to_str + +from .printer import Printer, print_function + + +class ReprPrinter(Printer): + printmethod = "_sympyrepr" + + _default_settings: dict[str, Any] = { + "order": None, + "perm_cyclic" : True, + } + + def reprify(self, args, sep): + """ + Prints each item in `args` and joins them with `sep`. + """ + return sep.join([self.doprint(item) for item in args]) + + def emptyPrinter(self, expr): + """ + The fallback printer. + """ + if isinstance(expr, str): + return expr + elif hasattr(expr, "__srepr__"): + return expr.__srepr__() + elif hasattr(expr, "args") and hasattr(expr.args, "__iter__"): + l = [] + for o in expr.args: + l.append(self._print(o)) + return expr.__class__.__name__ + '(%s)' % ', '.join(l) + elif hasattr(expr, "__module__") and hasattr(expr, "__name__"): + return "<'%s.%s'>" % (expr.__module__, expr.__name__) + else: + return str(expr) + + def _print_Add(self, expr, order=None): + args = self._as_ordered_terms(expr, order=order) + args = map(self._print, args) + clsname = type(expr).__name__ + return clsname + "(%s)" % ", ".join(args) + + def _print_Cycle(self, expr): + return expr.__repr__() + + def _print_Permutation(self, expr): + from sympy.combinatorics.permutations import Permutation, Cycle + from sympy.utilities.exceptions import sympy_deprecation_warning + + perm_cyclic = Permutation.print_cyclic + if perm_cyclic is not None: + sympy_deprecation_warning( + f""" + Setting Permutation.print_cyclic is deprecated. Instead use + init_printing(perm_cyclic={perm_cyclic}). + """, + deprecated_since_version="1.6", + active_deprecations_target="deprecated-permutation-print_cyclic", + stacklevel=7, + ) + else: + perm_cyclic = self._settings.get("perm_cyclic", True) + + if perm_cyclic: + if not expr.size: + return 'Permutation()' + # before taking Cycle notation, see if the last element is + # a singleton and move it to the head of the string + s = Cycle(expr)(expr.size - 1).__repr__()[len('Cycle'):] + last = s.rfind('(') + if not last == 0 and ',' not in s[last:]: + s = s[last:] + s[:last] + return 'Permutation%s' %s + else: + s = expr.support() + if not s: + if expr.size < 5: + return 'Permutation(%s)' % str(expr.array_form) + return 'Permutation([], size=%s)' % expr.size + trim = str(expr.array_form[:s[-1] + 1]) + ', size=%s' % expr.size + use = full = str(expr.array_form) + if len(trim) < len(full): + use = trim + return 'Permutation(%s)' % use + + def _print_Function(self, expr): + r = self._print(expr.func) + r += '(%s)' % ', '.join([self._print(a) for a in expr.args]) + return r + + def _print_Heaviside(self, expr): + # Same as _print_Function but uses pargs to suppress default value for + # 2nd arg. + r = self._print(expr.func) + r += '(%s)' % ', '.join([self._print(a) for a in expr.pargs]) + return r + + def _print_FunctionClass(self, expr): + if issubclass(expr, AppliedUndef): + return 'Function(%r)' % (expr.__name__) + else: + return expr.__name__ + + def _print_Half(self, expr): + return 'Rational(1, 2)' + + def _print_RationalConstant(self, expr): + return str(expr) + + def _print_AtomicExpr(self, expr): + return str(expr) + + def _print_NumberSymbol(self, expr): + return str(expr) + + def _print_Integer(self, expr): + return 'Integer(%i)' % expr.p + + def _print_Complexes(self, expr): + return 'Complexes' + + def _print_Integers(self, expr): + return 'Integers' + + def _print_Naturals(self, expr): + return 'Naturals' + + def _print_Naturals0(self, expr): + return 'Naturals0' + + def _print_Rationals(self, expr): + return 'Rationals' + + def _print_Reals(self, expr): + return 'Reals' + + def _print_EmptySet(self, expr): + return 'EmptySet' + + def _print_UniversalSet(self, expr): + return 'UniversalSet' + + def _print_EmptySequence(self, expr): + return 'EmptySequence' + + def _print_list(self, expr): + return "[%s]" % self.reprify(expr, ", ") + + def _print_dict(self, expr): + sep = ", " + dict_kvs = ["%s: %s" % (self.doprint(key), self.doprint(value)) for key, value in expr.items()] + return "{%s}" % sep.join(dict_kvs) + + def _print_set(self, expr): + if not expr: + return "set()" + return "{%s}" % self.reprify(expr, ", ") + + def _print_MatrixBase(self, expr): + # special case for some empty matrices + if (expr.rows == 0) ^ (expr.cols == 0): + return '%s(%s, %s, %s)' % (expr.__class__.__name__, + self._print(expr.rows), + self._print(expr.cols), + self._print([])) + l = [] + for i in range(expr.rows): + l.append([]) + for j in range(expr.cols): + l[-1].append(expr[i, j]) + return '%s(%s)' % (expr.__class__.__name__, self._print(l)) + + def _print_BooleanTrue(self, expr): + return "true" + + def _print_BooleanFalse(self, expr): + return "false" + + def _print_NaN(self, expr): + return "nan" + + def _print_Mul(self, expr, order=None): + if self.order not in ('old', 'none'): + args = expr.as_ordered_factors() + else: + # use make_args in case expr was something like -x -> x + args = Mul.make_args(expr) + + args = map(self._print, args) + clsname = type(expr).__name__ + return clsname + "(%s)" % ", ".join(args) + + def _print_Rational(self, expr): + return 'Rational(%s, %s)' % (self._print(expr.p), self._print(expr.q)) + + def _print_PythonRational(self, expr): + return "%s(%d, %d)" % (expr.__class__.__name__, expr.p, expr.q) + + def _print_Fraction(self, expr): + return 'Fraction(%s, %s)' % (self._print(expr.numerator), self._print(expr.denominator)) + + def _print_Float(self, expr): + r = mlib_to_str(expr._mpf_, repr_dps(expr._prec)) + return "%s('%s', precision=%i)" % (expr.__class__.__name__, r, expr._prec) + + def _print_Sum2(self, expr): + return "Sum2(%s, (%s, %s, %s))" % (self._print(expr.f), self._print(expr.i), + self._print(expr.a), self._print(expr.b)) + + def _print_Str(self, s): + return "%s(%s)" % (s.__class__.__name__, self._print(s.name)) + + def _print_Symbol(self, expr): + d = expr._assumptions_orig + # print the dummy_index like it was an assumption + if expr.is_Dummy: + d = d.copy() + d['dummy_index'] = expr.dummy_index + + if d == {}: + return "%s(%s)" % (expr.__class__.__name__, self._print(expr.name)) + else: + attr = ['%s=%s' % (k, v) for k, v in d.items()] + return "%s(%s, %s)" % (expr.__class__.__name__, + self._print(expr.name), ', '.join(attr)) + + def _print_CoordinateSymbol(self, expr): + d = expr._assumptions.generator + + if d == {}: + return "%s(%s, %s)" % ( + expr.__class__.__name__, + self._print(expr.coord_sys), + self._print(expr.index) + ) + else: + attr = ['%s=%s' % (k, v) for k, v in d.items()] + return "%s(%s, %s, %s)" % ( + expr.__class__.__name__, + self._print(expr.coord_sys), + self._print(expr.index), + ', '.join(attr) + ) + + def _print_Predicate(self, expr): + return "Q.%s" % expr.name + + def _print_AppliedPredicate(self, expr): + # will be changed to just expr.args when args overriding is removed + args = expr._args + return "%s(%s)" % (expr.__class__.__name__, self.reprify(args, ", ")) + + def _print_str(self, expr): + return repr(expr) + + def _print_tuple(self, expr): + if len(expr) == 1: + return "(%s,)" % self._print(expr[0]) + else: + return "(%s)" % self.reprify(expr, ", ") + + def _print_WildFunction(self, expr): + return "%s('%s')" % (expr.__class__.__name__, expr.name) + + def _print_AlgebraicNumber(self, expr): + return "%s(%s, %s)" % (expr.__class__.__name__, + self._print(expr.root), self._print(expr.coeffs())) + + def _print_PolyRing(self, ring): + return "%s(%s, %s, %s)" % (ring.__class__.__name__, + self._print(ring.symbols), self._print(ring.domain), self._print(ring.order)) + + def _print_FracField(self, field): + return "%s(%s, %s, %s)" % (field.__class__.__name__, + self._print(field.symbols), self._print(field.domain), self._print(field.order)) + + def _print_PolyElement(self, poly): + terms = list(poly.terms()) + terms.sort(key=poly.ring.order, reverse=True) + return "%s(%s, %s)" % (poly.__class__.__name__, self._print(poly.ring), self._print(terms)) + + def _print_FracElement(self, frac): + numer_terms = list(frac.numer.terms()) + numer_terms.sort(key=frac.field.order, reverse=True) + denom_terms = list(frac.denom.terms()) + denom_terms.sort(key=frac.field.order, reverse=True) + numer = self._print(numer_terms) + denom = self._print(denom_terms) + return "%s(%s, %s, %s)" % (frac.__class__.__name__, self._print(frac.field), numer, denom) + + def _print_FractionField(self, domain): + cls = domain.__class__.__name__ + field = self._print(domain.field) + return "%s(%s)" % (cls, field) + + def _print_PolynomialRingBase(self, ring): + cls = ring.__class__.__name__ + dom = self._print(ring.domain) + gens = ', '.join(map(self._print, ring.gens)) + order = str(ring.order) + if order != ring.default_order: + orderstr = ", order=" + order + else: + orderstr = "" + return "%s(%s, %s%s)" % (cls, dom, gens, orderstr) + + def _print_DMP(self, p): + cls = p.__class__.__name__ + rep = self._print(p.to_list()) + dom = self._print(p.dom) + return "%s(%s, %s)" % (cls, rep, dom) + + def _print_MonogenicFiniteExtension(self, ext): + # The expanded tree shown by srepr(ext.modulus) + # is not practical. + return "FiniteExtension(%s)" % str(ext.modulus) + + def _print_ExtensionElement(self, f): + rep = self._print(f.rep) + ext = self._print(f.ext) + return "ExtElem(%s, %s)" % (rep, ext) + +@print_function(ReprPrinter) +def srepr(expr, **settings): + """return expr in repr form""" + return ReprPrinter(settings).doprint(expr) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/rust.py b/.venv/lib/python3.13/site-packages/sympy/printing/rust.py new file mode 100644 index 0000000000000000000000000000000000000000..5bfd481bec6b7350df281accfb9b7a598cf05baa --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/rust.py @@ -0,0 +1,637 @@ +""" +Rust code printer + +The `RustCodePrinter` converts SymPy expressions into Rust expressions. + +A complete code generator, which uses `rust_code` extensively, can be found +in `sympy.utilities.codegen`. The `codegen` module can be used to generate +complete source code files. + +""" + +# Possible Improvement +# +# * make sure we follow Rust Style Guidelines_ +# * make use of pattern matching +# * better support for reference +# * generate generic code and use trait to make sure they have specific methods +# * use crates_ to get more math support +# - num_ +# + BigInt_, BigUint_ +# + Complex_ +# + Rational64_, Rational32_, BigRational_ +# +# .. _crates: https://crates.io/ +# .. _Guidelines: https://github.com/rust-lang/rust/tree/master/src/doc/style +# .. _num: http://rust-num.github.io/num/num/ +# .. _BigInt: http://rust-num.github.io/num/num/bigint/struct.BigInt.html +# .. _BigUint: http://rust-num.github.io/num/num/bigint/struct.BigUint.html +# .. _Complex: http://rust-num.github.io/num/num/complex/struct.Complex.html +# .. _Rational32: http://rust-num.github.io/num/num/rational/type.Rational32.html +# .. _Rational64: http://rust-num.github.io/num/num/rational/type.Rational64.html +# .. _BigRational: http://rust-num.github.io/num/num/rational/type.BigRational.html + +from __future__ import annotations +from functools import reduce +import operator +from typing import Any + +from sympy.codegen.ast import ( + float32, float64, int32, + real, integer, bool_ +) +from sympy.core import S, Rational, Float, Lambda +from sympy.core.expr import Expr +from sympy.core.numbers import equal_valued +from sympy.functions.elementary.integers import ceiling, floor +from sympy.printing.codeprinter import CodePrinter +from sympy.printing.precedence import PRECEDENCE + +# Rust's methods for integer and float can be found at here : +# +# * `Rust - Primitive Type f64 `_ +# * `Rust - Primitive Type i64 `_ +# +# Function Style : +# +# 1. args[0].func(args[1:]), method with arguments +# 2. args[0].func(), method without arguments +# 3. args[1].func(), method without arguments (e.g. (e, x) => x.exp()) +# 4. func(args), function with arguments + +# dictionary mapping SymPy function to (argument_conditions, Rust_function). +# Used in RustCodePrinter._print_Function(self) + +class float_floor(floor): + """ + Same as `sympy.floor`, but mimics the Rust behavior of returning a float rather than an integer + """ + def _eval_is_integer(self): + return False + +class float_ceiling(ceiling): + """ + Same as `sympy.ceiling`, but mimics the Rust behavior of returning a float rather than an integer + """ + def _eval_is_integer(self): + return False + + +function_overrides = { + "floor": (floor, float_floor), + "ceiling": (ceiling, float_ceiling), +} + +# f64 method in Rust +known_functions = { + # "": "is_nan", + # "": "is_infinite", + # "": "is_finite", + # "": "is_normal", + # "": "classify", + "float_floor": "floor", + "float_ceiling": "ceil", + # "": "round", + # "": "trunc", + # "": "fract", + "Abs": "abs", + # "": "signum", + # "": "is_sign_positive", + # "": "is_sign_negative", + # "": "mul_add", + "Pow": [(lambda base, exp: equal_valued(exp, -1), "recip", 2), # 1.0/x + (lambda base, exp: equal_valued(exp, 0.5), "sqrt", 2), # x ** 0.5 + (lambda base, exp: equal_valued(exp, -0.5), "sqrt().recip", 2), # 1/(x ** 0.5) + (lambda base, exp: exp == Rational(1, 3), "cbrt", 2), # x ** (1/3) + (lambda base, exp: equal_valued(base, 2), "exp2", 3), # 2 ** x + (lambda base, exp: exp.is_integer, "powi", 1), # x ** y, for i32 + (lambda base, exp: not exp.is_integer, "powf", 1)], # x ** y, for f64 + "exp": [(lambda exp: True, "exp", 2)], # e ** x + "log": "ln", + # "": "log", # number.log(base) + # "": "log2", + # "": "log10", + # "": "to_degrees", + # "": "to_radians", + "Max": "max", + "Min": "min", + # "": "hypot", # (x**2 + y**2) ** 0.5 + "sin": "sin", + "cos": "cos", + "tan": "tan", + "asin": "asin", + "acos": "acos", + "atan": "atan", + "atan2": "atan2", + # "": "sin_cos", + # "": "exp_m1", # e ** x - 1 + # "": "ln_1p", # ln(1 + x) + "sinh": "sinh", + "cosh": "cosh", + "tanh": "tanh", + "asinh": "asinh", + "acosh": "acosh", + "atanh": "atanh", + "sqrt": "sqrt", # To enable automatic rewrites +} + +# i64 method in Rust +# known_functions_i64 = { +# "": "min_value", +# "": "max_value", +# "": "from_str_radix", +# "": "count_ones", +# "": "count_zeros", +# "": "leading_zeros", +# "": "trainling_zeros", +# "": "rotate_left", +# "": "rotate_right", +# "": "swap_bytes", +# "": "from_be", +# "": "from_le", +# "": "to_be", # to big endian +# "": "to_le", # to little endian +# "": "checked_add", +# "": "checked_sub", +# "": "checked_mul", +# "": "checked_div", +# "": "checked_rem", +# "": "checked_neg", +# "": "checked_shl", +# "": "checked_shr", +# "": "checked_abs", +# "": "saturating_add", +# "": "saturating_sub", +# "": "saturating_mul", +# "": "wrapping_add", +# "": "wrapping_sub", +# "": "wrapping_mul", +# "": "wrapping_div", +# "": "wrapping_rem", +# "": "wrapping_neg", +# "": "wrapping_shl", +# "": "wrapping_shr", +# "": "wrapping_abs", +# "": "overflowing_add", +# "": "overflowing_sub", +# "": "overflowing_mul", +# "": "overflowing_div", +# "": "overflowing_rem", +# "": "overflowing_neg", +# "": "overflowing_shl", +# "": "overflowing_shr", +# "": "overflowing_abs", +# "Pow": "pow", +# "Abs": "abs", +# "sign": "signum", +# "": "is_positive", +# "": "is_negnative", +# } + +# These are the core reserved words in the Rust language. Taken from: +# https://doc.rust-lang.org/reference/keywords.html + +reserved_words = ['abstract', + 'as', + 'async', + 'await', + 'become', + 'box', + 'break', + 'const', + 'continue', + 'crate', + 'do', + 'dyn', + 'else', + 'enum', + 'extern', + 'false', + 'final', + 'fn', + 'for', + 'gen', + 'if', + 'impl', + 'in', + 'let', + 'loop', + 'macro', + 'match', + 'mod', + 'move', + 'mut', + 'override', + 'priv', + 'pub', + 'ref', + 'return', + 'Self', + 'self', + 'static', + 'struct', + 'super', + 'trait', + 'true', + 'try', + 'type', + 'typeof', + 'unsafe', + 'unsized', + 'use', + 'virtual', + 'where', + 'while', + 'yield'] + + +class TypeCast(Expr): + """ + The type casting operator of the Rust language. + """ + + def __init__(self, expr, type_) -> None: + super().__init__() + self.explicit = expr.is_integer and type_ is not integer + self._assumptions = expr._assumptions + if self.explicit: + setattr(self, 'precedence', PRECEDENCE["Func"] + 10) + + @property + def expr(self): + return self.args[0] + + @property + def type_(self): + return self.args[1] + + def sort_key(self, order=None): + return self.args[0].sort_key(order=order) + + +class RustCodePrinter(CodePrinter): + """A printer to convert SymPy expressions to strings of Rust code""" + printmethod = "_rust_code" + language = "Rust" + + type_aliases = { + integer: int32, + real: float64, + } + + type_mappings = { + int32: 'i32', + float32: 'f32', + float64: 'f64', + bool_: 'bool' + } + + _default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{ + 'precision': 17, + 'user_functions': {}, + 'contract': True, + 'dereference': set(), + }) + + def __init__(self, settings={}): + CodePrinter.__init__(self, settings) + self.known_functions = dict(known_functions) + userfuncs = settings.get('user_functions', {}) + self.known_functions.update(userfuncs) + self._dereference = set(settings.get('dereference', [])) + self.reserved_words = set(reserved_words) + self.function_overrides = function_overrides + + def _rate_index_position(self, p): + return p*5 + + def _get_statement(self, codestring): + return "%s;" % codestring + + def _get_comment(self, text): + return "// %s" % text + + def _declare_number_const(self, name, value): + type_ = self.type_mappings[self.type_aliases[real]] + return "const %s: %s = %s;" % (name, type_, value) + + def _format_code(self, lines): + return self.indent_code(lines) + + def _traverse_matrix_indices(self, mat): + rows, cols = mat.shape + return ((i, j) for i in range(rows) for j in range(cols)) + + def _get_loop_opening_ending(self, indices): + open_lines = [] + close_lines = [] + loopstart = "for %(var)s in %(start)s..%(end)s {" + for i in indices: + # Rust arrays start at 0 and end at dimension-1 + open_lines.append(loopstart % { + 'var': self._print(i), + 'start': self._print(i.lower), + 'end': self._print(i.upper + 1)}) + close_lines.append("}") + return open_lines, close_lines + + def _print_caller_var(self, expr): + if len(expr.args) > 1: + # for something like `sin(x + y + z)`, + # make sure we can get '(x + y + z).sin()' + # instead of 'x + y + z.sin()' + return '(' + self._print(expr) + ')' + elif expr.is_number: + return self._print(expr, _type=True) + else: + return self._print(expr) + + def _print_Function(self, expr): + """ + basic function for printing `Function` + + Function Style : + + 1. args[0].func(args[1:]), method with arguments + 2. args[0].func(), method without arguments + 3. args[1].func(), method without arguments (e.g. (e, x) => x.exp()) + 4. func(args), function with arguments + """ + + if expr.func.__name__ in self.known_functions: + cond_func = self.known_functions[expr.func.__name__] + func = None + style = 1 + if isinstance(cond_func, str): + func = cond_func + else: + for cond, func, style in cond_func: + if cond(*expr.args): + break + if func is not None: + if style == 1: + ret = "%(var)s.%(method)s(%(args)s)" % { + 'var': self._print_caller_var(expr.args[0]), + 'method': func, + 'args': self.stringify(expr.args[1:], ", ") if len(expr.args) > 1 else '' + } + elif style == 2: + ret = "%(var)s.%(method)s()" % { + 'var': self._print_caller_var(expr.args[0]), + 'method': func, + } + elif style == 3: + ret = "%(var)s.%(method)s()" % { + 'var': self._print_caller_var(expr.args[1]), + 'method': func, + } + else: + ret = "%(func)s(%(args)s)" % { + 'func': func, + 'args': self.stringify(expr.args, ", "), + } + return ret + elif hasattr(expr, '_imp_') and isinstance(expr._imp_, Lambda): + # inlined function + return self._print(expr._imp_(*expr.args)) + else: + return self._print_not_supported(expr) + + def _print_Mul(self, expr): + contains_floats = any(arg.is_real and not arg.is_integer for arg in expr.args) + if contains_floats: + expr = reduce(operator.mul,(self._cast_to_float(arg) if arg != -1 else arg for arg in expr.args)) + + return super()._print_Mul(expr) + + def _print_Add(self, expr, order=None): + contains_floats = any(arg.is_real and not arg.is_integer for arg in expr.args) + if contains_floats: + expr = reduce(operator.add, (self._cast_to_float(arg) for arg in expr.args)) + + return super()._print_Add(expr, order) + + def _print_Pow(self, expr): + if expr.base.is_integer and not expr.exp.is_integer: + expr = type(expr)(Float(expr.base), expr.exp) + return self._print(expr) + return self._print_Function(expr) + + def _print_TypeCast(self, expr): + if not expr.explicit: + return self._print(expr.expr) + else: + return self._print(expr.expr) + ' as %s' % self.type_mappings[self.type_aliases[expr.type_]] + + def _print_Float(self, expr, _type=False): + ret = super()._print_Float(expr) + if _type: + return ret + '_%s' % self.type_mappings[self.type_aliases[real]] + else: + return ret + + def _print_Integer(self, expr, _type=False): + ret = super()._print_Integer(expr) + if _type: + return ret + '_%s' % self.type_mappings[self.type_aliases[integer]] + else: + return ret + + def _print_Rational(self, expr): + p, q = int(expr.p), int(expr.q) + float_suffix = self.type_mappings[self.type_aliases[real]] + return '%d_%s/%d.0' % (p, float_suffix, q) + + def _print_Relational(self, expr): + if (expr.lhs.is_integer and not expr.rhs.is_integer) or (expr.rhs.is_integer and not expr.lhs.is_integer): + lhs = self._cast_to_float(expr.lhs) + rhs = self._cast_to_float(expr.rhs) + else: + lhs = expr.lhs + rhs = expr.rhs + lhs_code = self._print(lhs) + rhs_code = self._print(rhs) + op = expr.rel_op + return "{} {} {}".format(lhs_code, op, rhs_code) + + def _print_Indexed(self, expr): + # calculate index for 1d array + dims = expr.shape + elem = S.Zero + offset = S.One + for i in reversed(range(expr.rank)): + elem += expr.indices[i]*offset + offset *= dims[i] + return "%s[%s]" % (self._print(expr.base.label), self._print(elem)) + + def _print_Idx(self, expr): + return expr.label.name + + def _print_Dummy(self, expr): + return expr.name + + def _print_Exp1(self, expr, _type=False): + return "E" + + def _print_Pi(self, expr, _type=False): + return 'PI' + + def _print_Infinity(self, expr, _type=False): + return 'INFINITY' + + def _print_NegativeInfinity(self, expr, _type=False): + return 'NEG_INFINITY' + + def _print_BooleanTrue(self, expr, _type=False): + return "true" + + def _print_BooleanFalse(self, expr, _type=False): + return "false" + + def _print_bool(self, expr, _type=False): + return str(expr).lower() + + def _print_NaN(self, expr, _type=False): + return "NAN" + + def _print_Piecewise(self, expr): + if expr.args[-1].cond != True: + # We need the last conditional to be a True, otherwise the resulting + # function may not return a result. + raise ValueError("All Piecewise expressions must contain an " + "(expr, True) statement to be used as a default " + "condition. Without one, the generated " + "expression may not evaluate to anything under " + "some condition.") + lines = [] + + for i, (e, c) in enumerate(expr.args): + if i == 0: + lines.append("if (%s) {" % self._print(c)) + elif i == len(expr.args) - 1 and c == True: + lines[-1] += " else {" + else: + lines[-1] += " else if (%s) {" % self._print(c) + code0 = self._print(e) + lines.append(code0) + lines.append("}") + + if self._settings['inline']: + return " ".join(lines) + else: + return "\n".join(lines) + + def _print_ITE(self, expr): + from sympy.functions import Piecewise + return self._print(expr.rewrite(Piecewise, deep=False)) + + def _print_MatrixBase(self, A): + if A.cols == 1: + return "[%s]" % ", ".join(self._print(a) for a in A) + else: + raise ValueError("Full Matrix Support in Rust need Crates (https://crates.io/keywords/matrix).") + + def _print_SparseRepMatrix(self, mat): + # do not allow sparse matrices to be made dense + return self._print_not_supported(mat) + + def _print_MatrixElement(self, expr): + return "%s[%s]" % (expr.parent, + expr.j + expr.i*expr.parent.shape[1]) + + def _print_Symbol(self, expr): + + name = super()._print_Symbol(expr) + + if expr in self._dereference: + return '(*%s)' % name + else: + return name + + def _print_Assignment(self, expr): + from sympy.tensor.indexed import IndexedBase + lhs = expr.lhs + rhs = expr.rhs + if self._settings["contract"] and (lhs.has(IndexedBase) or + rhs.has(IndexedBase)): + # Here we check if there is looping to be done, and if so + # print the required loops. + return self._doprint_loops(rhs, lhs) + else: + lhs_code = self._print(lhs) + rhs_code = self._print(rhs) + return self._get_statement("%s = %s" % (lhs_code, rhs_code)) + + def _print_sign(self, expr): + arg = self._print(expr.args[0]) + return "(if (%s == 0.0) { 0.0 } else { (%s).signum() })" % (arg, arg) + + def _cast_to_float(self, expr): + if not expr.is_number: + return TypeCast(expr, real) + elif expr.is_integer: + return Float(expr) + return expr + + def _can_print(self, name): + """ Check if function ``name`` is either a known function or has its own + printing method. Used to check if rewriting is possible.""" + + # since the whole point of function_overrides is to enable proper printing, + # we presume they all are printable + + return name in self.known_functions or name in function_overrides or getattr(self, '_print_{}'.format(name), False) + + def _collect_functions(self, expr): + functions = set() + if isinstance(expr, Expr): + if expr.is_Function: + functions.add(expr.func) + for arg in expr.args: + functions = functions.union(self._collect_functions(arg)) + return functions + + def _rewrite_known_functions(self, expr): + if not isinstance(expr, Expr): + return expr + + expression_functions = self._collect_functions(expr) + rewriteable_functions = { + name: (target_f, required_fs) + for name, (target_f, required_fs) in self._rewriteable_functions.items() + if self._can_print(target_f) + and all(self._can_print(f) for f in required_fs) + } + for func in expression_functions: + target_f, _ = rewriteable_functions.get(func.__name__, (None, None)) + if target_f: + expr = expr.rewrite(target_f) + return expr + + def indent_code(self, code): + """Accepts a string of code or a list of code lines""" + + if isinstance(code, str): + code_lines = self.indent_code(code.splitlines(True)) + return ''.join(code_lines) + + tab = " " + inc_token = ('{', '(', '{\n', '(\n') + dec_token = ('}', ')') + + code = [ line.lstrip(' \t') for line in code ] + + increase = [ int(any(map(line.endswith, inc_token))) for line in code ] + decrease = [ int(any(map(line.startswith, dec_token))) + for line in code ] + + pretty = [] + level = 0 + for n, line in enumerate(code): + if line in ('', '\n'): + pretty.append(line) + continue + level -= decrease[n] + pretty.append("%s%s" % (tab*level, line)) + level += increase[n] + return pretty diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/smtlib.py b/.venv/lib/python3.13/site-packages/sympy/printing/smtlib.py new file mode 100644 index 0000000000000000000000000000000000000000..8fa015c6198cb32837eb3a0d7fe9d61352da25ad --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/smtlib.py @@ -0,0 +1,583 @@ +import typing + +import sympy +from sympy.core import Add, Mul +from sympy.core import Symbol, Expr, Float, Rational, Integer, Basic +from sympy.core.function import UndefinedFunction, Function +from sympy.core.relational import Relational, Unequality, Equality, LessThan, GreaterThan, StrictLessThan, StrictGreaterThan +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp, log, Pow +from sympy.functions.elementary.hyperbolic import sinh, cosh, tanh +from sympy.functions.elementary.miscellaneous import Min, Max +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import sin, cos, tan, asin, acos, atan, atan2 +from sympy.logic.boolalg import And, Or, Xor, Implies, Boolean +from sympy.logic.boolalg import BooleanTrue, BooleanFalse, BooleanFunction, Not, ITE +from sympy.printing.printer import Printer +from sympy.sets import Interval +from mpmath.libmp.libmpf import prec_to_dps, to_str as mlib_to_str +from sympy.assumptions.assume import AppliedPredicate +from sympy.assumptions.relation.binrel import AppliedBinaryRelation +from sympy.assumptions.ask import Q +from sympy.assumptions.relation.equality import StrictGreaterThanPredicate, StrictLessThanPredicate, GreaterThanPredicate, LessThanPredicate, EqualityPredicate + + +class SMTLibPrinter(Printer): + printmethod = "_smtlib" + + # based on dReal, an automated reasoning tool for solving problems that can be encoded as first-order logic formulas over the real numbers. + # dReal's special strength is in handling problems that involve a wide range of nonlinear real functions. + _default_settings: dict = { + 'precision': None, + 'known_types': { + bool: 'Bool', + int: 'Int', + float: 'Real' + }, + 'known_constants': { + # pi: 'MY_VARIABLE_PI_DECLARED_ELSEWHERE', + }, + 'known_functions': { + Add: '+', + Mul: '*', + + Equality: '=', + LessThan: '<=', + GreaterThan: '>=', + StrictLessThan: '<', + StrictGreaterThan: '>', + + EqualityPredicate(): '=', + LessThanPredicate(): '<=', + GreaterThanPredicate(): '>=', + StrictLessThanPredicate(): '<', + StrictGreaterThanPredicate(): '>', + + exp: 'exp', + log: 'log', + Abs: 'abs', + sin: 'sin', + cos: 'cos', + tan: 'tan', + asin: 'arcsin', + acos: 'arccos', + atan: 'arctan', + atan2: 'arctan2', + sinh: 'sinh', + cosh: 'cosh', + tanh: 'tanh', + Min: 'min', + Max: 'max', + Pow: 'pow', + + And: 'and', + Or: 'or', + Xor: 'xor', + Not: 'not', + ITE: 'ite', + Implies: '=>', + } + } + + symbol_table: dict + + def __init__(self, settings: typing.Optional[dict] = None, + symbol_table=None): + settings = settings or {} + self.symbol_table = symbol_table or {} + Printer.__init__(self, settings) + self._precision = self._settings['precision'] + self._known_types = dict(self._settings['known_types']) + self._known_constants = dict(self._settings['known_constants']) + self._known_functions = dict(self._settings['known_functions']) + + for _ in self._known_types.values(): assert self._is_legal_name(_) + for _ in self._known_constants.values(): assert self._is_legal_name(_) + # for _ in self._known_functions.values(): assert self._is_legal_name(_) # +, *, <, >, etc. + + def _is_legal_name(self, s: str): + if not s: return False + if s[0].isnumeric(): return False + return all(_.isalnum() or _ == '_' for _ in s) + + def _s_expr(self, op: str, args: typing.Union[list, tuple]) -> str: + args_str = ' '.join( + a if isinstance(a, str) + else self._print(a) + for a in args + ) + return f'({op} {args_str})' + + def _print_Function(self, e): + if e in self._known_functions: + op = self._known_functions[e] + elif type(e) in self._known_functions: + op = self._known_functions[type(e)] + elif type(type(e)) == UndefinedFunction: + op = e.name + elif isinstance(e, AppliedBinaryRelation) and e.function in self._known_functions: + op = self._known_functions[e.function] + return self._s_expr(op, e.arguments) + else: + op = self._known_functions[e] # throw KeyError + + return self._s_expr(op, e.args) + + def _print_Relational(self, e: Relational): + return self._print_Function(e) + + def _print_BooleanFunction(self, e: BooleanFunction): + return self._print_Function(e) + + def _print_Expr(self, e: Expr): + return self._print_Function(e) + + def _print_Unequality(self, e: Unequality): + if type(e) in self._known_functions: + return self._print_Relational(e) # default + else: + eq_op = self._known_functions[Equality] + not_op = self._known_functions[Not] + return self._s_expr(not_op, [self._s_expr(eq_op, e.args)]) + + def _print_Piecewise(self, e: Piecewise): + def _print_Piecewise_recursive(args: typing.Union[list, tuple]): + e, c = args[0] + if len(args) == 1: + assert (c is True) or isinstance(c, BooleanTrue) + return self._print(e) + else: + ite = self._known_functions[ITE] + return self._s_expr(ite, [ + c, e, _print_Piecewise_recursive(args[1:]) + ]) + + return _print_Piecewise_recursive(e.args) + + def _print_Interval(self, e: Interval): + if e.start.is_infinite and e.end.is_infinite: + return '' + elif e.start.is_infinite != e.end.is_infinite: + raise ValueError(f'One-sided intervals (`{e}`) are not supported in SMT.') + else: + return f'[{e.start}, {e.end}]' + + def _print_AppliedPredicate(self, e: AppliedPredicate): + if e.function == Q.positive: + rel = Q.gt(e.arguments[0],0) + elif e.function == Q.negative: + rel = Q.lt(e.arguments[0], 0) + elif e.function == Q.zero: + rel = Q.eq(e.arguments[0], 0) + elif e.function == Q.nonpositive: + rel = Q.le(e.arguments[0], 0) + elif e.function == Q.nonnegative: + rel = Q.ge(e.arguments[0], 0) + elif e.function == Q.nonzero: + rel = Q.ne(e.arguments[0], 0) + else: + raise ValueError(f"Predicate (`{e}`) is not handled.") + + return self._print_AppliedBinaryRelation(rel) + + def _print_AppliedBinaryRelation(self, e: AppliedPredicate): + if e.function == Q.ne: + return self._print_Unequality(Unequality(*e.arguments)) + else: + return self._print_Function(e) + + # todo: Sympy does not support quantifiers yet as of 2022, but quantifiers can be handy in SMT. + # For now, users can extend this class and build in their own quantifier support. + # See `test_quantifier_extensions()` in test_smtlib.py for an example of how this might look. + + # def _print_ForAll(self, e: ForAll): + # return self._s('forall', [ + # self._s('', [ + # self._s(sym.name, [self._type_name(sym), Interval(start, end)]) + # for sym, start, end in e.limits + # ]), + # e.function + # ]) + + def _print_BooleanTrue(self, x: BooleanTrue): + return 'true' + + def _print_BooleanFalse(self, x: BooleanFalse): + return 'false' + + def _print_Float(self, x: Float): + dps = prec_to_dps(x._prec) + str_real = mlib_to_str(x._mpf_, dps, strip_zeros=True, min_fixed=None, max_fixed=None) + + if 'e' in str_real: + (mant, exp) = str_real.split('e') + + if exp[0] == '+': + exp = exp[1:] + + mul = self._known_functions[Mul] + pow = self._known_functions[Pow] + + return r"(%s %s (%s 10 %s))" % (mul, mant, pow, exp) + elif str_real in ["+inf", "-inf"]: + raise ValueError("Infinite values are not supported in SMT.") + else: + return str_real + + def _print_float(self, x: float): + return self._print(Float(x)) + + def _print_Rational(self, x: Rational): + return self._s_expr('/', [x.p, x.q]) + + def _print_Integer(self, x: Integer): + assert x.q == 1 + return str(x.p) + + def _print_int(self, x: int): + return str(x) + + def _print_Symbol(self, x: Symbol): + assert self._is_legal_name(x.name) + return x.name + + def _print_NumberSymbol(self, x): + name = self._known_constants.get(x) + if name: + return name + else: + f = x.evalf(self._precision) if self._precision else x.evalf() + return self._print_Float(f) + + def _print_UndefinedFunction(self, x): + assert self._is_legal_name(x.name) + return x.name + + def _print_Exp1(self, x): + return ( + self._print_Function(exp(1, evaluate=False)) + if exp in self._known_functions else + self._print_NumberSymbol(x) + ) + + def emptyPrinter(self, expr): + raise NotImplementedError(f'Cannot convert `{repr(expr)}` of type `{type(expr)}` to SMT.') + + +def smtlib_code( + expr, + auto_assert=True, auto_declare=True, + precision=None, + symbol_table=None, + known_types=None, known_constants=None, known_functions=None, + prefix_expressions=None, suffix_expressions=None, + log_warn=None +): + r"""Converts ``expr`` to a string of smtlib code. + + Parameters + ========== + + expr : Expr | List[Expr] + A SymPy expression or system to be converted. + auto_assert : bool, optional + If false, do not modify expr and produce only the S-Expression equivalent of expr. + If true, assume expr is a system and assert each boolean element. + auto_declare : bool, optional + If false, do not produce declarations for the symbols used in expr. + If true, prepend all necessary declarations for variables used in expr based on symbol_table. + precision : integer, optional + The ``evalf(..)`` precision for numbers such as pi. + symbol_table : dict, optional + A dictionary where keys are ``Symbol`` or ``Function`` instances and values are their Python type i.e. ``bool``, ``int``, ``float``, or ``Callable[...]``. + If incomplete, an attempt will be made to infer types from ``expr``. + known_types: dict, optional + A dictionary where keys are ``bool``, ``int``, ``float`` etc. and values are their corresponding SMT type names. + If not given, a partial listing compatible with several solvers will be used. + known_functions : dict, optional + A dictionary where keys are ``Function``, ``Relational``, ``BooleanFunction``, or ``Expr`` instances and values are their SMT string representations. + If not given, a partial listing optimized for dReal solver (but compatible with others) will be used. + known_constants: dict, optional + A dictionary where keys are ``NumberSymbol`` instances and values are their SMT variable names. + When using this feature, extra caution must be taken to avoid naming collisions between user symbols and listed constants. + If not given, constants will be expanded inline i.e. ``3.14159`` instead of ``MY_SMT_VARIABLE_FOR_PI``. + prefix_expressions: list, optional + A list of lists of ``str`` and/or expressions to convert into SMTLib and prefix to the output. + suffix_expressions: list, optional + A list of lists of ``str`` and/or expressions to convert into SMTLib and postfix to the output. + log_warn: lambda function, optional + A function to record all warnings during potentially risky operations. + Soundness is a core value in SMT solving, so it is good to log all assumptions made. + + Examples + ======== + >>> from sympy import smtlib_code, symbols, sin, Eq + >>> x = symbols('x') + >>> smtlib_code(sin(x).series(x).removeO(), log_warn=print) + Could not infer type of `x`. Defaulting to float. + Non-Boolean expression `x**5/120 - x**3/6 + x` will not be asserted. Converting to SMTLib verbatim. + '(declare-const x Real)\n(+ x (* (/ -1 6) (pow x 3)) (* (/ 1 120) (pow x 5)))' + + >>> from sympy import Rational + >>> x, y, tau = symbols("x, y, tau") + >>> smtlib_code((2*tau)**Rational(7, 2), log_warn=print) + Could not infer type of `tau`. Defaulting to float. + Non-Boolean expression `8*sqrt(2)*tau**(7/2)` will not be asserted. Converting to SMTLib verbatim. + '(declare-const tau Real)\n(* 8 (pow 2 (/ 1 2)) (pow tau (/ 7 2)))' + + ``Piecewise`` expressions are implemented with ``ite`` expressions by default. + Note that if the ``Piecewise`` lacks a default term, represented by + ``(expr, True)`` then an error will be thrown. This is to prevent + generating an expression that may not evaluate to anything. + + >>> from sympy import Piecewise + >>> pw = Piecewise((x + 1, x > 0), (x, True)) + >>> smtlib_code(Eq(pw, 3), symbol_table={x: float}, log_warn=print) + '(declare-const x Real)\n(assert (= (ite (> x 0) (+ 1 x) x) 3))' + + Custom printing can be defined for certain types by passing a dictionary of + PythonType : "SMT Name" to the ``known_types``, ``known_constants``, and ``known_functions`` kwargs. + + >>> from typing import Callable + >>> from sympy import Function, Add + >>> f = Function('f') + >>> g = Function('g') + >>> smt_builtin_funcs = { # functions our SMT solver will understand + ... f: "existing_smtlib_fcn", + ... Add: "sum", + ... } + >>> user_def_funcs = { # functions defined by the user must have their types specified explicitly + ... g: Callable[[int], float], + ... } + >>> smtlib_code(f(x) + g(x), symbol_table=user_def_funcs, known_functions=smt_builtin_funcs, log_warn=print) + Non-Boolean expression `f(x) + g(x)` will not be asserted. Converting to SMTLib verbatim. + '(declare-const x Int)\n(declare-fun g (Int) Real)\n(sum (existing_smtlib_fcn x) (g x))' + """ + log_warn = log_warn or (lambda _: None) + + if not isinstance(expr, list): expr = [expr] + expr = [ + sympy.sympify(_, strict=True, evaluate=False, convert_xor=False) + for _ in expr + ] + + if not symbol_table: symbol_table = {} + symbol_table = _auto_infer_smtlib_types( + *expr, symbol_table=symbol_table + ) + # See [FALLBACK RULES] + # Need SMTLibPrinter to populate known_functions and known_constants first. + + settings = {} + if precision: settings['precision'] = precision + del precision + + if known_types: settings['known_types'] = known_types + del known_types + + if known_functions: settings['known_functions'] = known_functions + del known_functions + + if known_constants: settings['known_constants'] = known_constants + del known_constants + + if not prefix_expressions: prefix_expressions = [] + if not suffix_expressions: suffix_expressions = [] + + p = SMTLibPrinter(settings, symbol_table) + del symbol_table + + # [FALLBACK RULES] + for e in expr: + for sym in e.atoms(Symbol, Function): + if ( + sym.is_Symbol and + sym not in p._known_constants and + sym not in p.symbol_table + ): + log_warn(f"Could not infer type of `{sym}`. Defaulting to float.") + p.symbol_table[sym] = float + if ( + sym.is_Function and + type(sym) not in p._known_functions and + type(sym) not in p.symbol_table and + not sym.is_Piecewise + ): raise TypeError( + f"Unknown type of undefined function `{sym}`. " + f"Must be mapped to ``str`` in known_functions or mapped to ``Callable[..]`` in symbol_table." + ) + + declarations = [] + if auto_declare: + constants = {sym.name: sym for e in expr for sym in e.free_symbols + if sym not in p._known_constants} + functions = {fnc.name: fnc for e in expr for fnc in e.atoms(Function) + if type(fnc) not in p._known_functions and not fnc.is_Piecewise} + declarations = \ + [ + _auto_declare_smtlib(sym, p, log_warn) + for sym in constants.values() + ] + [ + _auto_declare_smtlib(fnc, p, log_warn) + for fnc in functions.values() + ] + declarations = [decl for decl in declarations if decl] + + if auto_assert: + expr = [_auto_assert_smtlib(e, p, log_warn) for e in expr] + + # return SMTLibPrinter().doprint(expr) + return '\n'.join([ + # ';; PREFIX EXPRESSIONS', + *[ + e if isinstance(e, str) else p.doprint(e) + for e in prefix_expressions + ], + + # ';; DECLARATIONS', + *sorted(e for e in declarations), + + # ';; EXPRESSIONS', + *[ + e if isinstance(e, str) else p.doprint(e) + for e in expr + ], + + # ';; SUFFIX EXPRESSIONS', + *[ + e if isinstance(e, str) else p.doprint(e) + for e in suffix_expressions + ], + ]) + + +def _auto_declare_smtlib(sym: typing.Union[Symbol, Function], p: SMTLibPrinter, log_warn: typing.Callable[[str], None]): + if sym.is_Symbol: + type_signature = p.symbol_table[sym] + assert isinstance(type_signature, type) + type_signature = p._known_types[type_signature] + return p._s_expr('declare-const', [sym, type_signature]) + + elif sym.is_Function: + type_signature = p.symbol_table[type(sym)] + assert callable(type_signature) + type_signature = [p._known_types[_] for _ in type_signature.__args__] + assert len(type_signature) > 0 + params_signature = f"({' '.join(type_signature[:-1])})" + return_signature = type_signature[-1] + return p._s_expr('declare-fun', [type(sym), params_signature, return_signature]) + + else: + log_warn(f"Non-Symbol/Function `{sym}` will not be declared.") + return None + + +def _auto_assert_smtlib(e: Expr, p: SMTLibPrinter, log_warn: typing.Callable[[str], None]): + if isinstance(e, Boolean) or ( + e in p.symbol_table and p.symbol_table[e] == bool + ) or ( + e.is_Function and + type(e) in p.symbol_table and + p.symbol_table[type(e)].__args__[-1] == bool + ): + return p._s_expr('assert', [e]) + else: + log_warn(f"Non-Boolean expression `{e}` will not be asserted. Converting to SMTLib verbatim.") + return e + + +def _auto_infer_smtlib_types( + *exprs: Basic, + symbol_table: typing.Optional[dict] = None +) -> dict: + # [TYPE INFERENCE RULES] + # X is alone in an expr => X is bool + # X in BooleanFunction.args => X is bool + # X matches to a bool param of a symbol_table function => X is bool + # X matches to an int param of a symbol_table function => X is int + # X.is_integer => X is int + # X == Y, where X is T => Y is T + + # [FALLBACK RULES] + # see _auto_declare_smtlib(..) + # X is not bool and X is not int and X is Symbol => X is float + # else (e.g. X is Function) => error. must be specified explicitly. + + _symbols = dict(symbol_table) if symbol_table else {} + + def safe_update(syms: set, inf): + for s in syms: + assert s.is_Symbol + if (old_type := _symbols.setdefault(s, inf)) != inf: + raise TypeError(f"Could not infer type of `{s}`. Apparently both `{old_type}` and `{inf}`?") + + # EXPLICIT TYPES + safe_update({ + e + for e in exprs + if e.is_Symbol + }, bool) + + safe_update({ + symbol + for e in exprs + for boolfunc in e.atoms(BooleanFunction) + for symbol in boolfunc.args + if symbol.is_Symbol + }, bool) + + safe_update({ + symbol + for e in exprs + for boolfunc in e.atoms(Function) + if type(boolfunc) in _symbols + for symbol, param in zip(boolfunc.args, _symbols[type(boolfunc)].__args__) + if symbol.is_Symbol and param == bool + }, bool) + + safe_update({ + symbol + for e in exprs + for intfunc in e.atoms(Function) + if type(intfunc) in _symbols + for symbol, param in zip(intfunc.args, _symbols[type(intfunc)].__args__) + if symbol.is_Symbol and param == int + }, int) + + safe_update({ + symbol + for e in exprs + for symbol in e.atoms(Symbol) + if symbol.is_integer + }, int) + + safe_update({ + symbol + for e in exprs + for symbol in e.atoms(Symbol) + if symbol.is_real and not symbol.is_integer + }, float) + + # EQUALITY RELATION RULE + rels_eq = [rel for expr in exprs for rel in expr.atoms(Equality)] + rels = [ + (rel.lhs, rel.rhs) for rel in rels_eq if rel.lhs.is_Symbol + ] + [ + (rel.rhs, rel.lhs) for rel in rels_eq if rel.rhs.is_Symbol + ] + for infer, reltd in rels: + inference = ( + _symbols[infer] if infer in _symbols else + _symbols[reltd] if reltd in _symbols else + + _symbols[type(reltd)].__args__[-1] + if reltd.is_Function and type(reltd) in _symbols else + + bool if reltd.is_Boolean else + int if reltd.is_integer or reltd.is_Integer else + float if reltd.is_real else + None + ) + if inference: safe_update({infer}, inference) + + return _symbols diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/tensorflow.py b/.venv/lib/python3.13/site-packages/sympy/printing/tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..78b0df62b611f336468769e4cee1695bc068eee9 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/tensorflow.py @@ -0,0 +1,224 @@ +import sympy.codegen +import sympy.codegen.cfunctions +from sympy.external.importtools import version_tuple +from collections.abc import Iterable + +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.codegen.cfunctions import Sqrt +from sympy.external import import_module +from sympy.printing.precedence import PRECEDENCE +from sympy.printing.pycode import AbstractPythonCodePrinter, ArrayPrinter +import sympy + +tensorflow = import_module('tensorflow') + +class TensorflowPrinter(ArrayPrinter, AbstractPythonCodePrinter): + """ + Tensorflow printer which handles vectorized piecewise functions, + logical operators, max/min, and relational operators. + """ + printmethod = "_tensorflowcode" + + mapping = { + sympy.Abs: "tensorflow.math.abs", + sympy.sign: "tensorflow.math.sign", + + # XXX May raise error for ints. + sympy.ceiling: "tensorflow.math.ceil", + sympy.floor: "tensorflow.math.floor", + sympy.log: "tensorflow.math.log", + sympy.exp: "tensorflow.math.exp", + Sqrt: "tensorflow.math.sqrt", + sympy.cos: "tensorflow.math.cos", + sympy.acos: "tensorflow.math.acos", + sympy.sin: "tensorflow.math.sin", + sympy.asin: "tensorflow.math.asin", + sympy.tan: "tensorflow.math.tan", + sympy.atan: "tensorflow.math.atan", + sympy.atan2: "tensorflow.math.atan2", + # XXX Also may give NaN for complex results. + sympy.cosh: "tensorflow.math.cosh", + sympy.acosh: "tensorflow.math.acosh", + sympy.sinh: "tensorflow.math.sinh", + sympy.asinh: "tensorflow.math.asinh", + sympy.tanh: "tensorflow.math.tanh", + sympy.atanh: "tensorflow.math.atanh", + + sympy.re: "tensorflow.math.real", + sympy.im: "tensorflow.math.imag", + sympy.arg: "tensorflow.math.angle", + + # XXX May raise error for ints and complexes + sympy.erf: "tensorflow.math.erf", + sympy.loggamma: "tensorflow.math.lgamma", + + sympy.Eq: "tensorflow.math.equal", + sympy.Ne: "tensorflow.math.not_equal", + sympy.StrictGreaterThan: "tensorflow.math.greater", + sympy.StrictLessThan: "tensorflow.math.less", + sympy.LessThan: "tensorflow.math.less_equal", + sympy.GreaterThan: "tensorflow.math.greater_equal", + + sympy.And: "tensorflow.math.logical_and", + sympy.Or: "tensorflow.math.logical_or", + sympy.Not: "tensorflow.math.logical_not", + sympy.Max: "tensorflow.math.maximum", + sympy.Min: "tensorflow.math.minimum", + + # Matrices + sympy.MatAdd: "tensorflow.math.add", + sympy.HadamardProduct: "tensorflow.math.multiply", + sympy.Trace: "tensorflow.linalg.trace", + + # XXX May raise error for integer matrices. + sympy.Determinant : "tensorflow.linalg.det", + } + + _default_settings = dict( + AbstractPythonCodePrinter._default_settings, + tensorflow_version=None + ) + + def __init__(self, settings=None): + super().__init__(settings) + + version = self._settings['tensorflow_version'] + if version is None and tensorflow: + version = tensorflow.__version__ + self.tensorflow_version = version + + def _print_Function(self, expr): + op = self.mapping.get(type(expr), None) + if op is None: + return super()._print_Basic(expr) + children = [self._print(arg) for arg in expr.args] + if len(children) == 1: + return "%s(%s)" % ( + self._module_format(op), + children[0] + ) + else: + return self._expand_fold_binary_op(op, children) + + _print_Expr = _print_Function + _print_Application = _print_Function + _print_MatrixExpr = _print_Function + # TODO: a better class structure would avoid this mess: + _print_Relational = _print_Function + _print_Not = _print_Function + _print_And = _print_Function + _print_Or = _print_Function + _print_HadamardProduct = _print_Function + _print_Trace = _print_Function + _print_Determinant = _print_Function + + def _print_Inverse(self, expr): + op = self._module_format('tensorflow.linalg.inv') + return "{}({})".format(op, self._print(expr.arg)) + + def _print_Transpose(self, expr): + version = self.tensorflow_version + if version and version_tuple(version) < version_tuple('1.14'): + op = self._module_format('tensorflow.matrix_transpose') + else: + op = self._module_format('tensorflow.linalg.matrix_transpose') + return "{}({})".format(op, self._print(expr.arg)) + + def _print_Derivative(self, expr): + variables = expr.variables + if any(isinstance(i, Iterable) for i in variables): + raise NotImplementedError("derivation by multiple variables is not supported") + def unfold(expr, args): + if not args: + return self._print(expr) + return "%s(%s, %s)[0]" % ( + self._module_format("tensorflow.gradients"), + unfold(expr, args[:-1]), + self._print(args[-1]), + ) + return unfold(expr.expr, variables) + + def _print_Piecewise(self, expr): + version = self.tensorflow_version + if version and version_tuple(version) < version_tuple('1.0'): + tensorflow_piecewise = "tensorflow.select" + else: + tensorflow_piecewise = "tensorflow.where" + + from sympy.functions.elementary.piecewise import Piecewise + e, cond = expr.args[0].args + if len(expr.args) == 1: + return '{}({}, {}, {})'.format( + self._module_format(tensorflow_piecewise), + self._print(cond), + self._print(e), + 0) + + return '{}({}, {}, {})'.format( + self._module_format(tensorflow_piecewise), + self._print(cond), + self._print(e), + self._print(Piecewise(*expr.args[1:]))) + + def _print_Pow(self, expr): + # XXX May raise error for + # int**float or int**complex or float**complex + base, exp = expr.args + if expr.exp == S.Half: + return "{}({})".format( + self._module_format("tensorflow.math.sqrt"), self._print(base)) + return "{}({}, {})".format( + self._module_format("tensorflow.math.pow"), + self._print(base), self._print(exp)) + + def _print_MatrixBase(self, expr): + tensorflow_f = "tensorflow.Variable" if expr.free_symbols else "tensorflow.constant" + data = "["+", ".join(["["+", ".join([self._print(j) for j in i])+"]" for i in expr.tolist()])+"]" + return "%s(%s)" % ( + self._module_format(tensorflow_f), + data, + ) + + def _print_MatMul(self, expr): + from sympy.matrices.expressions import MatrixExpr + mat_args = [arg for arg in expr.args if isinstance(arg, MatrixExpr)] + args = [arg for arg in expr.args if arg not in mat_args] + if args: + return "%s*%s" % ( + self.parenthesize(Mul.fromiter(args), PRECEDENCE["Mul"]), + self._expand_fold_binary_op( + "tensorflow.linalg.matmul", mat_args) + ) + else: + return self._expand_fold_binary_op( + "tensorflow.linalg.matmul", mat_args) + + def _print_MatPow(self, expr): + return self._expand_fold_binary_op( + "tensorflow.linalg.matmul", [expr.base]*expr.exp) + + def _print_CodeBlock(self, expr): + # TODO: is this necessary? + ret = [] + for subexpr in expr.args: + ret.append(self._print(subexpr)) + return "\n".join(ret) + + def _print_isnan(self, exp): + return f'tensorflow.math.is_nan({self._print(*exp.args)})' + + def _print_isinf(self, exp): + return f'tensorflow.math.is_inf({self._print(*exp.args)})' + + _module = "tensorflow" + _einsum = "linalg.einsum" + _add = "math.add" + _transpose = "transpose" + _ones = "ones" + _zeros = "zeros" + + +def tensorflow_code(expr, **settings): + printer = TensorflowPrinter(settings) + return printer.doprint(expr) diff --git a/.venv/lib/python3.13/site-packages/sympy/printing/theanocode.py b/.venv/lib/python3.13/site-packages/sympy/printing/theanocode.py new file mode 100644 index 0000000000000000000000000000000000000000..dce908865d426dabede2b6749ad944e5a420e4cf --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/printing/theanocode.py @@ -0,0 +1,571 @@ +""" +.. deprecated:: 1.8 + + ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to + Aesara. Use ``sympy.printing.aesaracode`` instead. See + :ref:`theanocode-deprecated` for more information. + +""" +from __future__ import annotations +import math +from typing import Any + +from sympy.external import import_module +from sympy.printing.printer import Printer +from sympy.utilities.iterables import is_sequence +import sympy +from functools import partial + +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.exceptions import sympy_deprecation_warning + + +__doctest_requires__ = {('theano_function',): ['theano']} + + +theano = import_module('theano') + + +if theano: + ts = theano.scalar + tt = theano.tensor + from theano.sandbox import linalg as tlinalg + + mapping = { + sympy.Add: tt.add, + sympy.Mul: tt.mul, + sympy.Abs: tt.abs_, + sympy.sign: tt.sgn, + sympy.ceiling: tt.ceil, + sympy.floor: tt.floor, + sympy.log: tt.log, + sympy.exp: tt.exp, + sympy.sqrt: tt.sqrt, + sympy.cos: tt.cos, + sympy.acos: tt.arccos, + sympy.sin: tt.sin, + sympy.asin: tt.arcsin, + sympy.tan: tt.tan, + sympy.atan: tt.arctan, + sympy.atan2: tt.arctan2, + sympy.cosh: tt.cosh, + sympy.acosh: tt.arccosh, + sympy.sinh: tt.sinh, + sympy.asinh: tt.arcsinh, + sympy.tanh: tt.tanh, + sympy.atanh: tt.arctanh, + sympy.re: tt.real, + sympy.im: tt.imag, + sympy.arg: tt.angle, + sympy.erf: tt.erf, + sympy.gamma: tt.gamma, + sympy.loggamma: tt.gammaln, + sympy.Pow: tt.pow, + sympy.Eq: tt.eq, + sympy.StrictGreaterThan: tt.gt, + sympy.StrictLessThan: tt.lt, + sympy.LessThan: tt.le, + sympy.GreaterThan: tt.ge, + sympy.And: tt.and_, + sympy.Or: tt.or_, + sympy.Max: tt.maximum, # SymPy accept >2 inputs, Theano only 2 + sympy.Min: tt.minimum, # SymPy accept >2 inputs, Theano only 2 + sympy.conjugate: tt.conj, + sympy.core.numbers.ImaginaryUnit: lambda:tt.complex(0,1), + # Matrices + sympy.MatAdd: tt.Elemwise(ts.add), + sympy.HadamardProduct: tt.Elemwise(ts.mul), + sympy.Trace: tlinalg.trace, + sympy.Determinant : tlinalg.det, + sympy.Inverse: tlinalg.matrix_inverse, + sympy.Transpose: tt.DimShuffle((False, False), [1, 0]), + } + + +class TheanoPrinter(Printer): + """ Code printer which creates Theano symbolic expression graphs. + + Parameters + ========== + + cache : dict + Cache dictionary to use. If None (default) will use + the global cache. To create a printer which does not depend on or alter + global state pass an empty dictionary. Note: the dictionary is not + copied on initialization of the printer and will be updated in-place, + so using the same dict object when creating multiple printers or making + multiple calls to :func:`.theano_code` or :func:`.theano_function` means + the cache is shared between all these applications. + + Attributes + ========== + + cache : dict + A cache of Theano variables which have been created for SymPy + symbol-like objects (e.g. :class:`sympy.core.symbol.Symbol` or + :class:`sympy.matrices.expressions.MatrixSymbol`). This is used to + ensure that all references to a given symbol in an expression (or + multiple expressions) are printed as the same Theano variable, which is + created only once. Symbols are differentiated only by name and type. The + format of the cache's contents should be considered opaque to the user. + """ + printmethod = "_theano" + + def __init__(self, *args, **kwargs): + self.cache = kwargs.pop('cache', {}) + super().__init__(*args, **kwargs) + + def _get_key(self, s, name=None, dtype=None, broadcastable=None): + """ Get the cache key for a SymPy object. + + Parameters + ========== + + s : sympy.core.basic.Basic + SymPy object to get key for. + + name : str + Name of object, if it does not have a ``name`` attribute. + """ + + if name is None: + name = s.name + + return (name, type(s), s.args, dtype, broadcastable) + + def _get_or_create(self, s, name=None, dtype=None, broadcastable=None): + """ + Get the Theano variable for a SymPy symbol from the cache, or create it + if it does not exist. + """ + + # Defaults + if name is None: + name = s.name + if dtype is None: + dtype = 'floatX' + if broadcastable is None: + broadcastable = () + + key = self._get_key(s, name, dtype=dtype, broadcastable=broadcastable) + + if key in self.cache: + return self.cache[key] + + value = tt.tensor(name=name, dtype=dtype, broadcastable=broadcastable) + self.cache[key] = value + return value + + def _print_Symbol(self, s, **kwargs): + dtype = kwargs.get('dtypes', {}).get(s) + bc = kwargs.get('broadcastables', {}).get(s) + return self._get_or_create(s, dtype=dtype, broadcastable=bc) + + def _print_AppliedUndef(self, s, **kwargs): + name = str(type(s)) + '_' + str(s.args[0]) + dtype = kwargs.get('dtypes', {}).get(s) + bc = kwargs.get('broadcastables', {}).get(s) + return self._get_or_create(s, name=name, dtype=dtype, broadcastable=bc) + + def _print_Basic(self, expr, **kwargs): + op = mapping[type(expr)] + children = [self._print(arg, **kwargs) for arg in expr.args] + return op(*children) + + def _print_Number(self, n, **kwargs): + # Integers already taken care of below, interpret as float + return float(n.evalf()) + + def _print_MatrixSymbol(self, X, **kwargs): + dtype = kwargs.get('dtypes', {}).get(X) + return self._get_or_create(X, dtype=dtype, broadcastable=(None, None)) + + def _print_DenseMatrix(self, X, **kwargs): + if not hasattr(tt, 'stacklists'): + raise NotImplementedError( + "Matrix translation not yet supported in this version of Theano") + + return tt.stacklists([ + [self._print(arg, **kwargs) for arg in L] + for L in X.tolist() + ]) + + _print_ImmutableMatrix = _print_ImmutableDenseMatrix = _print_DenseMatrix + + def _print_MatMul(self, expr, **kwargs): + children = [self._print(arg, **kwargs) for arg in expr.args] + result = children[0] + for child in children[1:]: + result = tt.dot(result, child) + return result + + def _print_MatPow(self, expr, **kwargs): + children = [self._print(arg, **kwargs) for arg in expr.args] + result = 1 + if isinstance(children[1], int) and children[1] > 0: + for i in range(children[1]): + result = tt.dot(result, children[0]) + else: + raise NotImplementedError('''Only non-negative integer + powers of matrices can be handled by Theano at the moment''') + return result + + def _print_MatrixSlice(self, expr, **kwargs): + parent = self._print(expr.parent, **kwargs) + rowslice = self._print(slice(*expr.rowslice), **kwargs) + colslice = self._print(slice(*expr.colslice), **kwargs) + return parent[rowslice, colslice] + + def _print_BlockMatrix(self, expr, **kwargs): + nrows, ncols = expr.blocks.shape + blocks = [[self._print(expr.blocks[r, c], **kwargs) + for c in range(ncols)] + for r in range(nrows)] + return tt.join(0, *[tt.join(1, *row) for row in blocks]) + + + def _print_slice(self, expr, **kwargs): + return slice(*[self._print(i, **kwargs) + if isinstance(i, sympy.Basic) else i + for i in (expr.start, expr.stop, expr.step)]) + + def _print_Pi(self, expr, **kwargs): + return math.pi + + def _print_Exp1(self, expr, **kwargs): + return ts.exp(1) + + def _print_Piecewise(self, expr, **kwargs): + import numpy as np + e, cond = expr.args[0].args # First condition and corresponding value + + # Print conditional expression and value for first condition + p_cond = self._print(cond, **kwargs) + p_e = self._print(e, **kwargs) + + # One condition only + if len(expr.args) == 1: + # Return value if condition else NaN + return tt.switch(p_cond, p_e, np.nan) + + # Return value_1 if condition_1 else evaluate remaining conditions + p_remaining = self._print(sympy.Piecewise(*expr.args[1:]), **kwargs) + return tt.switch(p_cond, p_e, p_remaining) + + def _print_Rational(self, expr, **kwargs): + return tt.true_div(self._print(expr.p, **kwargs), + self._print(expr.q, **kwargs)) + + def _print_Integer(self, expr, **kwargs): + return expr.p + + def _print_factorial(self, expr, **kwargs): + return self._print(sympy.gamma(expr.args[0] + 1), **kwargs) + + def _print_Derivative(self, deriv, **kwargs): + rv = self._print(deriv.expr, **kwargs) + for var in deriv.variables: + var = self._print(var, **kwargs) + rv = tt.Rop(rv, var, tt.ones_like(var)) + return rv + + def emptyPrinter(self, expr): + return expr + + def doprint(self, expr, dtypes=None, broadcastables=None): + """ Convert a SymPy expression to a Theano graph variable. + + The ``dtypes`` and ``broadcastables`` arguments are used to specify the + data type, dimension, and broadcasting behavior of the Theano variables + corresponding to the free symbols in ``expr``. Each is a mapping from + SymPy symbols to the value of the corresponding argument to + ``theano.tensor.Tensor``. + + See the corresponding `documentation page`__ for more information on + broadcasting in Theano. + + .. __: http://deeplearning.net/software/theano/tutorial/broadcasting.html + + Parameters + ========== + + expr : sympy.core.expr.Expr + SymPy expression to print. + + dtypes : dict + Mapping from SymPy symbols to Theano datatypes to use when creating + new Theano variables for those symbols. Corresponds to the ``dtype`` + argument to ``theano.tensor.Tensor``. Defaults to ``'floatX'`` + for symbols not included in the mapping. + + broadcastables : dict + Mapping from SymPy symbols to the value of the ``broadcastable`` + argument to ``theano.tensor.Tensor`` to use when creating Theano + variables for those symbols. Defaults to the empty tuple for symbols + not included in the mapping (resulting in a scalar). + + Returns + ======= + + theano.gof.graph.Variable + A variable corresponding to the expression's value in a Theano + symbolic expression graph. + + """ + if dtypes is None: + dtypes = {} + if broadcastables is None: + broadcastables = {} + + return self._print(expr, dtypes=dtypes, broadcastables=broadcastables) + + +global_cache: dict[Any, Any] = {} + + +def theano_code(expr, cache=None, **kwargs): + """ + Convert a SymPy expression into a Theano graph variable. + + .. deprecated:: 1.8 + + ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to + Aesara. Use ``sympy.printing.aesaracode`` instead. See + :ref:`theanocode-deprecated` for more information. + + Parameters + ========== + + expr : sympy.core.expr.Expr + SymPy expression object to convert. + + cache : dict + Cached Theano variables (see :class:`TheanoPrinter.cache + `). Defaults to the module-level global cache. + + dtypes : dict + Passed to :meth:`.TheanoPrinter.doprint`. + + broadcastables : dict + Passed to :meth:`.TheanoPrinter.doprint`. + + Returns + ======= + + theano.gof.graph.Variable + A variable corresponding to the expression's value in a Theano symbolic + expression graph. + + """ + sympy_deprecation_warning( + """ + sympy.printing.theanocode is deprecated. Theano has been renamed to + Aesara. Use sympy.printing.aesaracode instead.""", + deprecated_since_version="1.8", + active_deprecations_target='theanocode-deprecated') + + if not theano: + raise ImportError("theano is required for theano_code") + + if cache is None: + cache = global_cache + + return TheanoPrinter(cache=cache, settings={}).doprint(expr, **kwargs) + + +def dim_handling(inputs, dim=None, dims=None, broadcastables=None): + r""" + Get value of ``broadcastables`` argument to :func:`.theano_code` from + keyword arguments to :func:`.theano_function`. + + Included for backwards compatibility. + + Parameters + ========== + + inputs + Sequence of input symbols. + + dim : int + Common number of dimensions for all inputs. Overrides other arguments + if given. + + dims : dict + Mapping from input symbols to number of dimensions. Overrides + ``broadcastables`` argument if given. + + broadcastables : dict + Explicit value of ``broadcastables`` argument to + :meth:`.TheanoPrinter.doprint`. If not None function will return this value unchanged. + + Returns + ======= + dict + Dictionary mapping elements of ``inputs`` to their "broadcastable" + values (tuple of ``bool``\ s). + """ + if dim is not None: + return dict.fromkeys(inputs, (False,) * dim) + + if dims is not None: + maxdim = max(dims.values()) + return { + s: (False,) * d + (True,) * (maxdim - d) + for s, d in dims.items() + } + + if broadcastables is not None: + return broadcastables + + return {} + + +@doctest_depends_on(modules=('theano',)) +def theano_function(inputs, outputs, scalar=False, *, + dim=None, dims=None, broadcastables=None, **kwargs): + """ + Create a Theano function from SymPy expressions. + + .. deprecated:: 1.8 + + ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to + Aesara. Use ``sympy.printing.aesaracode`` instead. See + :ref:`theanocode-deprecated` for more information. + + The inputs and outputs are converted to Theano variables using + :func:`.theano_code` and then passed to ``theano.function``. + + Parameters + ========== + + inputs + Sequence of symbols which constitute the inputs of the function. + + outputs + Sequence of expressions which constitute the outputs(s) of the + function. The free symbols of each expression must be a subset of + ``inputs``. + + scalar : bool + Convert 0-dimensional arrays in output to scalars. This will return a + Python wrapper function around the Theano function object. + + cache : dict + Cached Theano variables (see :class:`TheanoPrinter.cache + `). Defaults to the module-level global cache. + + dtypes : dict + Passed to :meth:`.TheanoPrinter.doprint`. + + broadcastables : dict + Passed to :meth:`.TheanoPrinter.doprint`. + + dims : dict + Alternative to ``broadcastables`` argument. Mapping from elements of + ``inputs`` to integers indicating the dimension of their associated + arrays/tensors. Overrides ``broadcastables`` argument if given. + + dim : int + Another alternative to the ``broadcastables`` argument. Common number of + dimensions to use for all arrays/tensors. + ``theano_function([x, y], [...], dim=2)`` is equivalent to using + ``broadcastables={x: (False, False), y: (False, False)}``. + + Returns + ======= + callable + A callable object which takes values of ``inputs`` as positional + arguments and returns an output array for each of the expressions + in ``outputs``. If ``outputs`` is a single expression the function will + return a Numpy array, if it is a list of multiple expressions the + function will return a list of arrays. See description of the ``squeeze`` + argument above for the behavior when a single output is passed in a list. + The returned object will either be an instance of + ``theano.compile.function_module.Function`` or a Python wrapper + function around one. In both cases, the returned value will have a + ``theano_function`` attribute which points to the return value of + ``theano.function``. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.printing.theanocode import theano_function + + A simple function with one input and one output: + + >>> f1 = theano_function([x], [x**2 - 1], scalar=True) + >>> f1(3) + 8.0 + + A function with multiple inputs and one output: + + >>> f2 = theano_function([x, y, z], [(x**z + y**z)**(1/z)], scalar=True) + >>> f2(3, 4, 2) + 5.0 + + A function with multiple inputs and multiple outputs: + + >>> f3 = theano_function([x, y], [x**2 + y**2, x**2 - y**2], scalar=True) + >>> f3(2, 3) + [13.0, -5.0] + + See also + ======== + + dim_handling + + """ + sympy_deprecation_warning( + """ + sympy.printing.theanocode is deprecated. Theano has been renamed to Aesara. Use sympy.printing.aesaracode instead""", + deprecated_since_version="1.8", + active_deprecations_target='theanocode-deprecated') + + if not theano: + raise ImportError("theano is required for theano_function") + + # Pop off non-theano keyword args + cache = kwargs.pop('cache', {}) + dtypes = kwargs.pop('dtypes', {}) + + broadcastables = dim_handling( + inputs, dim=dim, dims=dims, broadcastables=broadcastables, + ) + + # Print inputs/outputs + code = partial(theano_code, cache=cache, dtypes=dtypes, + broadcastables=broadcastables) + tinputs = list(map(code, inputs)) + toutputs = list(map(code, outputs)) + + #fix constant expressions as variables + toutputs = [output if isinstance(output, theano.Variable) else tt.as_tensor_variable(output) for output in toutputs] + + if len(toutputs) == 1: + toutputs = toutputs[0] + + # Compile theano func + func = theano.function(tinputs, toutputs, **kwargs) + + is_0d = [len(o.variable.broadcastable) == 0 for o in func.outputs] + + # No wrapper required + if not scalar or not any(is_0d): + func.theano_function = func + return func + + # Create wrapper to convert 0-dimensional outputs to scalars + def wrapper(*args): + out = func(*args) + # out can be array(1.0) or [array(1.0), array(2.0)] + + if is_sequence(out): + return [o[()] if is_0d[i] else o for i, o in enumerate(out)] + else: + return out[()] + + wrapper.__wrapped__ = func + wrapper.__doc__ = func.__doc__ + wrapper.theano_function = func + return wrapper diff --git a/.venv/lib/python3.13/site-packages/sympy/sandbox/__init__.py b/.venv/lib/python3.13/site-packages/sympy/sandbox/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a84b7517819bb2fc9886274e09d955a74cabca1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sandbox/__init__.py @@ -0,0 +1,8 @@ +""" +Sandbox module of SymPy. + +This module contains experimental code, use at your own risk! + +There is no warranty that this code will still be located here in future +versions of SymPy. +""" diff --git a/.venv/lib/python3.13/site-packages/sympy/sandbox/indexed_integrals.py b/.venv/lib/python3.13/site-packages/sympy/sandbox/indexed_integrals.py new file mode 100644 index 0000000000000000000000000000000000000000..c0c17d141448b5a71cb814bff76a710a5bd43f88 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sandbox/indexed_integrals.py @@ -0,0 +1,72 @@ +from sympy.tensor import Indexed +from sympy.core.containers import Tuple +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.integrals.integrals import Integral + + +class IndexedIntegral(Integral): + """ + Experimental class to test integration by indexed variables. + + Usage is analogue to ``Integral``, it simply adds awareness of + integration over indices. + + Contraction of non-identical index symbols referring to the same + ``IndexedBase`` is not yet supported. + + Examples + ======== + + >>> from sympy.sandbox.indexed_integrals import IndexedIntegral + >>> from sympy import IndexedBase, symbols + >>> A = IndexedBase('A') + >>> i, j = symbols('i j', integer=True) + >>> ii = IndexedIntegral(A[i], A[i]) + >>> ii + Integral(_A[i], _A[i]) + >>> ii.doit() + A[i]**2/2 + + If the indices are different, indexed objects are considered to be + different variables: + + >>> i2 = IndexedIntegral(A[j], A[i]) + >>> i2 + Integral(A[j], _A[i]) + >>> i2.doit() + A[i]*A[j] + """ + + def __new__(cls, function, *limits, **assumptions): + repl, limits = IndexedIntegral._indexed_process_limits(limits) + function = sympify(function) + function = function.xreplace(repl) + obj = Integral.__new__(cls, function, *limits, **assumptions) + obj._indexed_repl = repl + obj._indexed_reverse_repl = {val: key for key, val in repl.items()} + return obj + + def doit(self): + res = super().doit() + return res.xreplace(self._indexed_reverse_repl) + + @staticmethod + def _indexed_process_limits(limits): + repl = {} + newlimits = [] + for i in limits: + if isinstance(i, (tuple, list, Tuple)): + v = i[0] + vrest = i[1:] + else: + v = i + vrest = () + if isinstance(v, Indexed): + if v not in repl: + r = Dummy(str(v)) + repl[v] = r + newlimits.append((r,)+vrest) + else: + newlimits.append(i) + return repl, newlimits diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/__init__.py b/.venv/lib/python3.13/site-packages/sympy/sets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b909c0b5ef03b1e1e76dfbf4288f61860575da7 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/__init__.py @@ -0,0 +1,36 @@ +from .sets import (Set, Interval, Union, FiniteSet, ProductSet, + Intersection, imageset, Complement, SymmetricDifference, + DisjointUnion) + +from .fancysets import ImageSet, Range, ComplexRegion +from .contains import Contains +from .conditionset import ConditionSet +from .ordinals import Ordinal, OmegaPower, ord0 +from .powerset import PowerSet +from ..core.singleton import S +from .handlers.comparison import _eval_is_eq # noqa:F401 +Complexes = S.Complexes +EmptySet = S.EmptySet +Integers = S.Integers +Naturals = S.Naturals +Naturals0 = S.Naturals0 +Rationals = S.Rationals +Reals = S.Reals +UniversalSet = S.UniversalSet + +__all__ = [ + 'Set', 'Interval', 'Union', 'EmptySet', 'FiniteSet', 'ProductSet', + 'Intersection', 'imageset', 'Complement', 'SymmetricDifference', 'DisjointUnion', + + 'ImageSet', 'Range', 'ComplexRegion', 'Reals', + + 'Contains', + + 'ConditionSet', + + 'Ordinal', 'OmegaPower', 'ord0', + + 'PowerSet', + + 'Reals', 'Naturals', 'Naturals0', 'UniversalSet', 'Integers', 'Rationals', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/conditionset.py b/.venv/lib/python3.13/site-packages/sympy/sets/conditionset.py new file mode 100644 index 0000000000000000000000000000000000000000..e847e60ce97d7e9922ce907042ace941838b0ab1 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/conditionset.py @@ -0,0 +1,246 @@ +from sympy.core.singleton import S +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.function import Lambda, BadSignatureError +from sympy.core.logic import fuzzy_bool +from sympy.core.relational import Eq +from sympy.core.symbol import Dummy +from sympy.core.sympify import _sympify +from sympy.logic.boolalg import And, as_Boolean +from sympy.utilities.iterables import sift, flatten, has_dups +from sympy.utilities.exceptions import sympy_deprecation_warning +from .contains import Contains +from .sets import Set, Union, FiniteSet, SetKind + + +adummy = Dummy('conditionset') + + +class ConditionSet(Set): + r""" + Set of elements which satisfies a given condition. + + .. math:: \{x \mid \textrm{condition}(x) = \texttt{True}, x \in S\} + + Examples + ======== + + >>> from sympy import Symbol, S, ConditionSet, pi, Eq, sin, Interval + >>> from sympy.abc import x, y, z + + >>> sin_sols = ConditionSet(x, Eq(sin(x), 0), Interval(0, 2*pi)) + >>> 2*pi in sin_sols + True + >>> pi/2 in sin_sols + False + >>> 3*pi in sin_sols + False + >>> 5 in ConditionSet(x, x**2 > 4, S.Reals) + True + + If the value is not in the base set, the result is false: + + >>> 5 in ConditionSet(x, x**2 > 4, Interval(2, 4)) + False + + Notes + ===== + + Symbols with assumptions should be avoided or else the + condition may evaluate without consideration of the set: + + >>> n = Symbol('n', negative=True) + >>> cond = (n > 0); cond + False + >>> ConditionSet(n, cond, S.Integers) + EmptySet + + Only free symbols can be changed by using `subs`: + + >>> c = ConditionSet(x, x < 1, {x, z}) + >>> c.subs(x, y) + ConditionSet(x, x < 1, {y, z}) + + To check if ``pi`` is in ``c`` use: + + >>> pi in c + False + + If no base set is specified, the universal set is implied: + + >>> ConditionSet(x, x < 1).base_set + UniversalSet + + Only symbols or symbol-like expressions can be used: + + >>> ConditionSet(x + 1, x + 1 < 1, S.Integers) + Traceback (most recent call last): + ... + ValueError: non-symbol dummy not recognized in condition + + When the base set is a ConditionSet, the symbols will be + unified if possible with preference for the outermost symbols: + + >>> ConditionSet(x, x < y, ConditionSet(z, z + y < 2, S.Integers)) + ConditionSet(x, (x < y) & (x + y < 2), Integers) + + """ + def __new__(cls, sym, condition, base_set=S.UniversalSet): + sym = _sympify(sym) + flat = flatten([sym]) + if has_dups(flat): + raise BadSignatureError("Duplicate symbols detected") + base_set = _sympify(base_set) + if not isinstance(base_set, Set): + raise TypeError( + 'base set should be a Set object, not %s' % base_set) + condition = _sympify(condition) + + if isinstance(condition, FiniteSet): + condition_orig = condition + temp = (Eq(lhs, 0) for lhs in condition) + condition = And(*temp) + sympy_deprecation_warning( + f""" +Using a set for the condition in ConditionSet is deprecated. Use a boolean +instead. + +In this case, replace + + {condition_orig} + +with + + {condition} +""", + deprecated_since_version='1.5', + active_deprecations_target="deprecated-conditionset-set", + ) + + condition = as_Boolean(condition) + + if condition is S.true: + return base_set + + if condition is S.false: + return S.EmptySet + + if base_set is S.EmptySet: + return S.EmptySet + + # no simple answers, so now check syms + for i in flat: + if not getattr(i, '_diff_wrt', False): + raise ValueError('`%s` is not symbol-like' % i) + + if base_set.contains(sym) is S.false: + raise TypeError('sym `%s` is not in base_set `%s`' % (sym, base_set)) + + know = None + if isinstance(base_set, FiniteSet): + sifted = sift( + base_set, lambda _: fuzzy_bool(condition.subs(sym, _))) + if sifted[None]: + know = FiniteSet(*sifted[True]) + base_set = FiniteSet(*sifted[None]) + else: + return FiniteSet(*sifted[True]) + + if isinstance(base_set, cls): + s, c, b = base_set.args + def sig(s): + return cls(s, Eq(adummy, 0)).as_dummy().sym + sa, sb = map(sig, (sym, s)) + if sa != sb: + raise BadSignatureError('sym does not match sym of base set') + reps = dict(zip(flatten([sym]), flatten([s]))) + if s == sym: + condition = And(condition, c) + base_set = b + elif not c.free_symbols & sym.free_symbols: + reps = {v: k for k, v in reps.items()} + condition = And(condition, c.xreplace(reps)) + base_set = b + elif not condition.free_symbols & s.free_symbols: + sym = sym.xreplace(reps) + condition = And(condition.xreplace(reps), c) + base_set = b + + # flatten ConditionSet(Contains(ConditionSet())) expressions + if isinstance(condition, Contains) and (sym == condition.args[0]): + if isinstance(condition.args[1], Set): + return condition.args[1].intersect(base_set) + + rv = Basic.__new__(cls, sym, condition, base_set) + return rv if know is None else Union(know, rv) + + sym = property(lambda self: self.args[0]) + condition = property(lambda self: self.args[1]) + base_set = property(lambda self: self.args[2]) + + @property + def free_symbols(self): + cond_syms = self.condition.free_symbols - self.sym.free_symbols + return cond_syms | self.base_set.free_symbols + + @property + def bound_symbols(self): + return flatten([self.sym]) + + def _contains(self, other): + def ok_sig(a, b): + tuples = [isinstance(i, Tuple) for i in (a, b)] + c = tuples.count(True) + if c == 1: + return False + if c == 0: + return True + return len(a) == len(b) and all( + ok_sig(i, j) for i, j in zip(a, b)) + if not ok_sig(self.sym, other): + return S.false + + # try doing base_cond first and return + # False immediately if it is False + base_cond = Contains(other, self.base_set) + if base_cond is S.false: + return S.false + + # Substitute other into condition. This could raise e.g. for + # ConditionSet(x, 1/x >= 0, Reals).contains(0) + lamda = Lambda((self.sym,), self.condition) + try: + lambda_cond = lamda(other) + except TypeError: + return None + else: + return And(base_cond, lambda_cond) + + def as_relational(self, other): + f = Lambda(self.sym, self.condition) + if isinstance(self.sym, Tuple): + f = f(*other) + else: + f = f(other) + return And(f, self.base_set.contains(other)) + + def _eval_subs(self, old, new): + sym, cond, base = self.args + dsym = sym.subs(old, adummy) + insym = dsym.has(adummy) + # prioritize changing a symbol in the base + newbase = base.subs(old, new) + if newbase != base: + if not insym: + cond = cond.subs(old, new) + return self.func(sym, cond, newbase) + if insym: + pass # no change of bound symbols via subs + elif getattr(new, '_diff_wrt', False): + cond = cond.subs(old, new) + else: + pass # let error about the symbol raise from __new__ + return self.func(sym, cond, base) + + def _kind(self): + return SetKind(self.sym.kind) diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/contains.py b/.venv/lib/python3.13/site-packages/sympy/sets/contains.py new file mode 100644 index 0000000000000000000000000000000000000000..403d4875279d718724a898efa5cba41bc7bed6ea --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/contains.py @@ -0,0 +1,63 @@ +from sympy.core import S +from sympy.core.sympify import sympify +from sympy.core.relational import Eq, Ne +from sympy.core.parameters import global_parameters +from sympy.logic.boolalg import Boolean +from sympy.utilities.misc import func_name +from .sets import Set + + +class Contains(Boolean): + """ + Asserts that x is an element of the set S. + + Examples + ======== + + >>> from sympy import Symbol, Integer, S, Contains + >>> Contains(Integer(2), S.Integers) + True + >>> Contains(Integer(-2), S.Naturals) + False + >>> i = Symbol('i', integer=True) + >>> Contains(i, S.Naturals) + Contains(i, Naturals) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Element_%28mathematics%29 + """ + def __new__(cls, x, s, evaluate=None): + x = sympify(x) + s = sympify(s) + + if evaluate is None: + evaluate = global_parameters.evaluate + + if not isinstance(s, Set): + raise TypeError('expecting Set, not %s' % func_name(s)) + + if evaluate: + # _contains can return symbolic booleans that would be returned by + # s.contains(x) but here for Contains(x, s) we only evaluate to + # true, false or return the unevaluated Contains. + result = s._contains(x) + + if isinstance(result, Boolean): + if result in (S.true, S.false): + return result + elif result is not None: + raise TypeError("_contains() should return Boolean or None") + + return super().__new__(cls, x, s) + + @property + def binary_symbols(self): + return set().union(*[i.binary_symbols + for i in self.args[1].args + if i.is_Boolean or i.is_Symbol or + isinstance(i, (Eq, Ne))]) + + def as_set(self): + return self.args[1] diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/fancysets.py b/.venv/lib/python3.13/site-packages/sympy/sets/fancysets.py new file mode 100644 index 0000000000000000000000000000000000000000..e0e24a2a864222d16ba1a697558b5211127fb2ad --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/fancysets.py @@ -0,0 +1,1523 @@ +from functools import reduce +from itertools import product + +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import Lambda +from sympy.core.logic import fuzzy_not, fuzzy_or, fuzzy_and +from sympy.core.mod import Mod +from sympy.core.intfunc import igcd +from sympy.core.numbers import oo, Rational, Integer +from sympy.core.relational import Eq, is_eq +from sympy.core.kind import NumberKind +from sympy.core.singleton import Singleton, S +from sympy.core.symbol import Dummy, symbols, Symbol +from sympy.core.sympify import _sympify, sympify, _sympy_converter +from sympy.functions.elementary.integers import ceiling, floor +from sympy.functions.elementary.trigonometric import sin, cos +from sympy.logic.boolalg import And, Or +from .sets import tfn, Set, Interval, Union, FiniteSet, ProductSet, SetKind +from sympy.utilities.misc import filldedent + + +class Rationals(Set, metaclass=Singleton): + """ + Represents the rational numbers. This set is also available as + the singleton ``S.Rationals``. + + Examples + ======== + + >>> from sympy import S + >>> S.Half in S.Rationals + True + >>> iterable = iter(S.Rationals) + >>> [next(iterable) for i in range(12)] + [0, 1, -1, 1/2, 2, -1/2, -2, 1/3, 3, -1/3, -3, 2/3] + """ + + is_iterable = True + _inf = S.NegativeInfinity + _sup = S.Infinity + is_empty = False + is_finite_set = False + + def _contains(self, other): + if not isinstance(other, Expr): + return S.false + return tfn[other.is_rational] + + def __iter__(self): + yield S.Zero + yield S.One + yield S.NegativeOne + d = 2 + while True: + for n in range(d): + if igcd(n, d) == 1: + yield Rational(n, d) + yield Rational(d, n) + yield Rational(-n, d) + yield Rational(-d, n) + d += 1 + + @property + def _boundary(self): + return S.Reals + + def _kind(self): + return SetKind(NumberKind) + + +class Naturals(Set, metaclass=Singleton): + """ + Represents the natural numbers (or counting numbers) which are all + positive integers starting from 1. This set is also available as + the singleton ``S.Naturals``. + + Examples + ======== + + >>> from sympy import S, Interval, pprint + >>> 5 in S.Naturals + True + >>> iterable = iter(S.Naturals) + >>> next(iterable) + 1 + >>> next(iterable) + 2 + >>> next(iterable) + 3 + >>> pprint(S.Naturals.intersect(Interval(0, 10))) + {1, 2, ..., 10} + + See Also + ======== + + Naturals0 : non-negative integers (i.e. includes 0, too) + Integers : also includes negative integers + """ + + is_iterable = True + _inf: Integer = S.One + _sup = S.Infinity + is_empty = False + is_finite_set = False + + def _contains(self, other): + if not isinstance(other, Expr): + return S.false + elif other.is_positive and other.is_integer: + return S.true + elif other.is_integer is False or other.is_positive is False: + return S.false + + def _eval_is_subset(self, other): + return Range(1, oo).is_subset(other) + + def _eval_is_superset(self, other): + return Range(1, oo).is_superset(other) + + def __iter__(self): + i = self._inf + while True: + yield i + i = i + 1 + + @property + def _boundary(self): + return self + + def as_relational(self, x): + return And(Eq(floor(x), x), x >= self.inf, x < oo) + + def _kind(self): + return SetKind(NumberKind) + + +class Naturals0(Naturals): + """Represents the whole numbers which are all the non-negative integers, + inclusive of zero. + + See Also + ======== + + Naturals : positive integers; does not include 0 + Integers : also includes the negative integers + """ + _inf = S.Zero + + def _contains(self, other): + if not isinstance(other, Expr): + return S.false + elif other.is_integer and other.is_nonnegative: + return S.true + elif other.is_integer is False or other.is_nonnegative is False: + return S.false + + def _eval_is_subset(self, other): + return Range(oo).is_subset(other) + + def _eval_is_superset(self, other): + return Range(oo).is_superset(other) + + +class Integers(Set, metaclass=Singleton): + """ + Represents all integers: positive, negative and zero. This set is also + available as the singleton ``S.Integers``. + + Examples + ======== + + >>> from sympy import S, Interval, pprint + >>> 5 in S.Naturals + True + >>> iterable = iter(S.Integers) + >>> next(iterable) + 0 + >>> next(iterable) + 1 + >>> next(iterable) + -1 + >>> next(iterable) + 2 + + >>> pprint(S.Integers.intersect(Interval(-4, 4))) + {-4, -3, ..., 4} + + See Also + ======== + + Naturals0 : non-negative integers + Integers : positive and negative integers and zero + """ + + is_iterable = True + is_empty = False + is_finite_set = False + + def _contains(self, other): + if not isinstance(other, Expr): + return S.false + return tfn[other.is_integer] + + def __iter__(self): + yield S.Zero + i = S.One + while True: + yield i + yield -i + i = i + 1 + + @property + def _inf(self): + return S.NegativeInfinity + + @property + def _sup(self): + return S.Infinity + + @property + def _boundary(self): + return self + + def _kind(self): + return SetKind(NumberKind) + + def as_relational(self, x): + return And(Eq(floor(x), x), -oo < x, x < oo) + + def _eval_is_subset(self, other): + return Range(-oo, oo).is_subset(other) + + def _eval_is_superset(self, other): + return Range(-oo, oo).is_superset(other) + + +class Reals(Interval, metaclass=Singleton): + """ + Represents all real numbers + from negative infinity to positive infinity, + including all integer, rational and irrational numbers. + This set is also available as the singleton ``S.Reals``. + + + Examples + ======== + + >>> from sympy import S, Rational, pi, I + >>> 5 in S.Reals + True + >>> Rational(-1, 2) in S.Reals + True + >>> pi in S.Reals + True + >>> 3*I in S.Reals + False + >>> S.Reals.contains(pi) + True + + + See Also + ======== + + ComplexRegion + """ + @property + def start(self): + return S.NegativeInfinity + + @property + def end(self): + return S.Infinity + + @property + def left_open(self): + return True + + @property + def right_open(self): + return True + + def __eq__(self, other): + return other == Interval(S.NegativeInfinity, S.Infinity) + + def __hash__(self): + return hash(Interval(S.NegativeInfinity, S.Infinity)) + + +class ImageSet(Set): + """ + Image of a set under a mathematical function. The transformation + must be given as a Lambda function which has as many arguments + as the elements of the set upon which it operates, e.g. 1 argument + when acting on the set of integers or 2 arguments when acting on + a complex region. + + This function is not normally called directly, but is called + from ``imageset``. + + + Examples + ======== + + >>> from sympy import Symbol, S, pi, Dummy, Lambda + >>> from sympy import FiniteSet, ImageSet, Interval + + >>> x = Symbol('x') + >>> N = S.Naturals + >>> squares = ImageSet(Lambda(x, x**2), N) # {x**2 for x in N} + >>> 4 in squares + True + >>> 5 in squares + False + + >>> FiniteSet(0, 1, 2, 3, 4, 5, 6, 7, 9, 10).intersect(squares) + {1, 4, 9} + + >>> square_iterable = iter(squares) + >>> for i in range(4): + ... next(square_iterable) + 1 + 4 + 9 + 16 + + If you want to get value for `x` = 2, 1/2 etc. (Please check whether the + `x` value is in ``base_set`` or not before passing it as args) + + >>> squares.lamda(2) + 4 + >>> squares.lamda(S(1)/2) + 1/4 + + >>> n = Dummy('n') + >>> solutions = ImageSet(Lambda(n, n*pi), S.Integers) # solutions of sin(x) = 0 + >>> dom = Interval(-1, 1) + >>> dom.intersect(solutions) + {0} + + See Also + ======== + + sympy.sets.sets.imageset + """ + def __new__(cls, flambda, *sets): + if not isinstance(flambda, Lambda): + raise ValueError('First argument must be a Lambda') + + signature = flambda.signature + + if len(signature) != len(sets): + raise ValueError('Incompatible signature') + + sets = [_sympify(s) for s in sets] + + if not all(isinstance(s, Set) for s in sets): + raise TypeError("Set arguments to ImageSet should of type Set") + + if not all(cls._check_sig(sg, st) for sg, st in zip(signature, sets)): + raise ValueError("Signature %s does not match sets %s" % (signature, sets)) + + if flambda is S.IdentityFunction and len(sets) == 1: + return sets[0] + + if not set(flambda.variables) & flambda.expr.free_symbols: + is_empty = fuzzy_or(s.is_empty for s in sets) + if is_empty == True: + return S.EmptySet + elif is_empty == False: + return FiniteSet(flambda.expr) + + return Basic.__new__(cls, flambda, *sets) + + lamda = property(lambda self: self.args[0]) + base_sets = property(lambda self: self.args[1:]) + + @property + def base_set(self): + # XXX: Maybe deprecate this? It is poorly defined in handling + # the multivariate case... + sets = self.base_sets + if len(sets) == 1: + return sets[0] + else: + return ProductSet(*sets).flatten() + + @property + def base_pset(self): + return ProductSet(*self.base_sets) + + @classmethod + def _check_sig(cls, sig_i, set_i): + if sig_i.is_symbol: + return True + elif isinstance(set_i, ProductSet): + sets = set_i.sets + if len(sig_i) != len(sets): + return False + # Recurse through the signature for nested tuples: + return all(cls._check_sig(ts, ps) for ts, ps in zip(sig_i, sets)) + else: + # XXX: Need a better way of checking whether a set is a set of + # Tuples or not. For example a FiniteSet can contain Tuples + # but so can an ImageSet or a ConditionSet. Others like + # Integers, Reals etc can not contain Tuples. We could just + # list the possibilities here... Current code for e.g. + # _contains probably only works for ProductSet. + return True # Give the benefit of the doubt + + def __iter__(self): + already_seen = set() + for i in self.base_pset: + val = self.lamda(*i) + if val in already_seen: + continue + else: + already_seen.add(val) + yield val + + def _is_multivariate(self): + return len(self.lamda.variables) > 1 + + def _contains(self, other): + from sympy.solvers.solveset import _solveset_multi + + def get_symsetmap(signature, base_sets): + '''Attempt to get a map of symbols to base_sets''' + queue = list(zip(signature, base_sets)) + symsetmap = {} + for sig, base_set in queue: + if sig.is_symbol: + symsetmap[sig] = base_set + elif base_set.is_ProductSet: + sets = base_set.sets + if len(sig) != len(sets): + raise ValueError("Incompatible signature") + # Recurse + queue.extend(zip(sig, sets)) + else: + # If we get here then we have something like sig = (x, y) and + # base_set = {(1, 2), (3, 4)}. For now we give up. + return None + + return symsetmap + + def get_equations(expr, candidate): + '''Find the equations relating symbols in expr and candidate.''' + queue = [(expr, candidate)] + for e, c in queue: + if not isinstance(e, Tuple): + yield Eq(e, c) + elif not isinstance(c, Tuple) or len(e) != len(c): + yield False + return + else: + queue.extend(zip(e, c)) + + # Get the basic objects together: + other = _sympify(other) + expr = self.lamda.expr + sig = self.lamda.signature + variables = self.lamda.variables + base_sets = self.base_sets + + # Use dummy symbols for ImageSet parameters so they don't match + # anything in other + rep = {v: Dummy(v.name) for v in variables} + variables = [v.subs(rep) for v in variables] + sig = sig.subs(rep) + expr = expr.subs(rep) + + # Map the parts of other to those in the Lambda expr + equations = [] + for eq in get_equations(expr, other): + # Unsatisfiable equation? + if eq is False: + return S.false + equations.append(eq) + + # Map the symbols in the signature to the corresponding domains + symsetmap = get_symsetmap(sig, base_sets) + if symsetmap is None: + # Can't factor the base sets to a ProductSet + return None + + # Which of the variables in the Lambda signature need to be solved for? + symss = (eq.free_symbols for eq in equations) + variables = set(variables) & reduce(set.union, symss, set()) + + # Use internal multivariate solveset + variables = tuple(variables) + base_sets = [symsetmap[v] for v in variables] + solnset = _solveset_multi(equations, variables, base_sets) + if solnset is None: + return None + return tfn[fuzzy_not(solnset.is_empty)] + + @property + def is_iterable(self): + return all(s.is_iterable for s in self.base_sets) + + def doit(self, **hints): + from sympy.sets.setexpr import SetExpr + f = self.lamda + sig = f.signature + if len(sig) == 1 and sig[0].is_symbol and isinstance(f.expr, Expr): + base_set = self.base_sets[0] + return SetExpr(base_set)._eval_func(f).set + if all(s.is_FiniteSet for s in self.base_sets): + return FiniteSet(*(f(*a) for a in product(*self.base_sets))) + return self + + def _kind(self): + return SetKind(self.lamda.expr.kind) + + +class Range(Set): + """ + Represents a range of integers. Can be called as ``Range(stop)``, + ``Range(start, stop)``, or ``Range(start, stop, step)``; when ``step`` is + not given it defaults to 1. + + ``Range(stop)`` is the same as ``Range(0, stop, 1)`` and the stop value + (just as for Python ranges) is not included in the Range values. + + >>> from sympy import Range + >>> list(Range(3)) + [0, 1, 2] + + The step can also be negative: + + >>> list(Range(10, 0, -2)) + [10, 8, 6, 4, 2] + + The stop value is made canonical so equivalent ranges always + have the same args: + + >>> Range(0, 10, 3) + Range(0, 12, 3) + + Infinite ranges are allowed. ``oo`` and ``-oo`` are never included in the + set (``Range`` is always a subset of ``Integers``). If the starting point + is infinite, then the final value is ``stop - step``. To iterate such a + range, it needs to be reversed: + + >>> from sympy import oo + >>> r = Range(-oo, 1) + >>> r[-1] + 0 + >>> next(iter(r)) + Traceback (most recent call last): + ... + TypeError: Cannot iterate over Range with infinite start + >>> next(iter(r.reversed)) + 0 + + Although ``Range`` is a :class:`Set` (and supports the normal set + operations) it maintains the order of the elements and can + be used in contexts where ``range`` would be used. + + >>> from sympy import Interval + >>> Range(0, 10, 2).intersect(Interval(3, 7)) + Range(4, 8, 2) + >>> list(_) + [4, 6] + + Although slicing of a Range will always return a Range -- possibly + empty -- an empty set will be returned from any intersection that + is empty: + + >>> Range(3)[:0] + Range(0, 0, 1) + >>> Range(3).intersect(Interval(4, oo)) + EmptySet + >>> Range(3).intersect(Range(4, oo)) + EmptySet + + Range will accept symbolic arguments but has very limited support + for doing anything other than displaying the Range: + + >>> from sympy import Symbol, pprint + >>> from sympy.abc import i, j, k + >>> Range(i, j, k).start + i + >>> Range(i, j, k).inf + Traceback (most recent call last): + ... + ValueError: invalid method for symbolic range + + Better success will be had when using integer symbols: + + >>> n = Symbol('n', integer=True) + >>> r = Range(n, n + 20, 3) + >>> r.inf + n + >>> pprint(r) + {n, n + 3, ..., n + 18} + """ + + def __new__(cls, *args): + if len(args) == 1: + if isinstance(args[0], range): + raise TypeError( + 'use sympify(%s) to convert range to Range' % args[0]) + + # expand range + slc = slice(*args) + + if slc.step == 0: + raise ValueError("step cannot be 0") + + start, stop, step = slc.start or 0, slc.stop, slc.step or 1 + try: + ok = [] + for w in (start, stop, step): + w = sympify(w) + if w in [S.NegativeInfinity, S.Infinity] or ( + w.has(Symbol) and w.is_integer != False): + ok.append(w) + elif not w.is_Integer: + if w.is_infinite: + raise ValueError('infinite symbols not allowed') + raise ValueError + else: + ok.append(w) + except ValueError: + raise ValueError(filldedent(''' + Finite arguments to Range must be integers; `imageset` can define + other cases, e.g. use `imageset(i, i/10, Range(3))` to give + [0, 1/10, 1/5].''')) + start, stop, step = ok + + null = False + if any(i.has(Symbol) for i in (start, stop, step)): + dif = stop - start + n = dif/step + if n.is_Rational: + if dif == 0: + null = True + else: # (x, x + 5, 2) or (x, 3*x, x) + n = floor(n) + end = start + n*step + if dif.is_Rational: # (x, x + 5, 2) + if (end - stop).is_negative: + end += step + else: # (x, 3*x, x) + if (end/stop - 1).is_negative: + end += step + elif n.is_extended_negative: + null = True + else: + end = stop # other methods like sup and reversed must fail + elif start.is_infinite: + span = step*(stop - start) + if span is S.NaN or span <= 0: + null = True + elif step.is_Integer and stop.is_infinite and abs(step) != 1: + raise ValueError(filldedent(''' + Step size must be %s in this case.''' % (1 if step > 0 else -1))) + else: + end = stop + else: + oostep = step.is_infinite + if oostep: + step = S.One if step > 0 else S.NegativeOne + n = ceiling((stop - start)/step) + if n <= 0: + null = True + elif oostep: + step = S.One # make it canonical + end = start + step + else: + end = start + n*step + if null: + start = end = S.Zero + step = S.One + return Basic.__new__(cls, start, end, step) + + start = property(lambda self: self.args[0]) + stop = property(lambda self: self.args[1]) + step = property(lambda self: self.args[2]) + + @property + def reversed(self): + """Return an equivalent Range in the opposite order. + + Examples + ======== + + >>> from sympy import Range + >>> Range(10).reversed + Range(9, -1, -1) + """ + if self.has(Symbol): + n = (self.stop - self.start)/self.step + if not n.is_extended_positive or not all( + i.is_integer or i.is_infinite for i in self.args): + raise ValueError('invalid method for symbolic range') + if self.start == self.stop: + return self + return self.func( + self.stop - self.step, self.start - self.step, -self.step) + + def _kind(self): + return SetKind(NumberKind) + + def _contains(self, other): + if self.start == self.stop: + return S.false + if other.is_infinite: + return S.false + if not other.is_integer: + return tfn[other.is_integer] + if self.has(Symbol): + n = (self.stop - self.start)/self.step + if not n.is_extended_positive or not all( + i.is_integer or i.is_infinite for i in self.args): + return + else: + n = self.size + if self.start.is_finite: + ref = self.start + elif self.stop.is_finite: + ref = self.stop + else: # both infinite; step is +/- 1 (enforced by __new__) + return S.true + if n == 1: + return Eq(other, self[0]) + res = (ref - other) % self.step + if res == S.Zero: + if self.has(Symbol): + d = Dummy('i') + return self.as_relational(d).subs(d, other) + return And(other >= self.inf, other <= self.sup) + elif res.is_Integer: # off sequence + return S.false + else: # symbolic/unsimplified residue modulo step + return None + + def __iter__(self): + n = self.size # validate + if not (n.has(S.Infinity) or n.has(S.NegativeInfinity) or n.is_Integer): + raise TypeError("Cannot iterate over symbolic Range") + if self.start in [S.NegativeInfinity, S.Infinity]: + raise TypeError("Cannot iterate over Range with infinite start") + elif self.start != self.stop: + i = self.start + if n.is_infinite: + while True: + yield i + i += self.step + else: + for _ in range(n): + yield i + i += self.step + + @property + def is_iterable(self): + # Check that size can be determined, used by __iter__ + dif = self.stop - self.start + n = dif/self.step + if not (n.has(S.Infinity) or n.has(S.NegativeInfinity) or n.is_Integer): + return False + if self.start in [S.NegativeInfinity, S.Infinity]: + return False + if not (n.is_extended_nonnegative and all(i.is_integer for i in self.args)): + return False + return True + + def __len__(self): + rv = self.size + if rv is S.Infinity: + raise ValueError('Use .size to get the length of an infinite Range') + return int(rv) + + @property + def size(self): + if self.start == self.stop: + return S.Zero + dif = self.stop - self.start + n = dif/self.step + if n.is_infinite: + return S.Infinity + if n.is_extended_nonnegative and all(i.is_integer for i in self.args): + return abs(floor(n)) + raise ValueError('Invalid method for symbolic Range') + + @property + def is_finite_set(self): + if self.start.is_integer and self.stop.is_integer: + return True + return self.size.is_finite + + @property + def is_empty(self): + try: + return self.size.is_zero + except ValueError: + return None + + def __bool__(self): + # this only distinguishes between definite null range + # and non-null/unknown null; getting True doesn't mean + # that it actually is not null + b = is_eq(self.start, self.stop) + if b is None: + raise ValueError('cannot tell if Range is null or not') + return not bool(b) + + def __getitem__(self, i): + ooslice = "cannot slice from the end with an infinite value" + zerostep = "slice step cannot be zero" + infinite = "slicing not possible on range with infinite start" + # if we had to take every other element in the following + # oo, ..., 6, 4, 2, 0 + # we might get oo, ..., 4, 0 or oo, ..., 6, 2 + ambiguous = "cannot unambiguously re-stride from the end " + \ + "with an infinite value" + if isinstance(i, slice): + if self.size.is_finite: # validates, too + if self.start == self.stop: + return Range(0) + start, stop, step = i.indices(self.size) + n = ceiling((stop - start)/step) + if n <= 0: + return Range(0) + canonical_stop = start + n*step + end = canonical_stop - step + ss = step*self.step + return Range(self[start], self[end] + ss, ss) + else: # infinite Range + start = i.start + stop = i.stop + if i.step == 0: + raise ValueError(zerostep) + step = i.step or 1 + ss = step*self.step + #--------------------- + # handle infinite Range + # i.e. Range(-oo, oo) or Range(oo, -oo, -1) + # -------------------- + if self.start.is_infinite and self.stop.is_infinite: + raise ValueError(infinite) + #--------------------- + # handle infinite on right + # e.g. Range(0, oo) or Range(0, -oo, -1) + # -------------------- + if self.stop.is_infinite: + # start and stop are not interdependent -- + # they only depend on step --so we use the + # equivalent reversed values + return self.reversed[ + stop if stop is None else -stop + 1: + start if start is None else -start: + step].reversed + #--------------------- + # handle infinite on the left + # e.g. Range(oo, 0, -1) or Range(-oo, 0) + # -------------------- + # consider combinations of + # start/stop {== None, < 0, == 0, > 0} and + # step {< 0, > 0} + if start is None: + if stop is None: + if step < 0: + return Range(self[-1], self.start, ss) + elif step > 1: + raise ValueError(ambiguous) + else: # == 1 + return self + elif stop < 0: + if step < 0: + return Range(self[-1], self[stop], ss) + else: # > 0 + return Range(self.start, self[stop], ss) + elif stop == 0: + if step > 0: + return Range(0) + else: # < 0 + raise ValueError(ooslice) + elif stop == 1: + if step > 0: + raise ValueError(ooslice) # infinite singleton + else: # < 0 + raise ValueError(ooslice) + else: # > 1 + raise ValueError(ooslice) + elif start < 0: + if stop is None: + if step < 0: + return Range(self[start], self.start, ss) + else: # > 0 + return Range(self[start], self.stop, ss) + elif stop < 0: + return Range(self[start], self[stop], ss) + elif stop == 0: + if step < 0: + raise ValueError(ooslice) + else: # > 0 + return Range(0) + elif stop > 0: + raise ValueError(ooslice) + elif start == 0: + if stop is None: + if step < 0: + raise ValueError(ooslice) # infinite singleton + elif step > 1: + raise ValueError(ambiguous) + else: # == 1 + return self + elif stop < 0: + if step > 1: + raise ValueError(ambiguous) + elif step == 1: + return Range(self.start, self[stop], ss) + else: # < 0 + return Range(0) + else: # >= 0 + raise ValueError(ooslice) + elif start > 0: + raise ValueError(ooslice) + else: + if self.start == self.stop: + raise IndexError('Range index out of range') + if not (all(i.is_integer or i.is_infinite + for i in self.args) and ((self.stop - self.start)/ + self.step).is_extended_positive): + raise ValueError('Invalid method for symbolic Range') + if i == 0: + if self.start.is_infinite: + raise ValueError(ooslice) + return self.start + if i == -1: + if self.stop.is_infinite: + raise ValueError(ooslice) + return self.stop - self.step + n = self.size # must be known for any other index + rv = (self.stop if i < 0 else self.start) + i*self.step + if rv.is_infinite: + raise ValueError(ooslice) + val = (rv - self.start)/self.step + rel = fuzzy_or([val.is_infinite, + fuzzy_and([val.is_nonnegative, (n-val).is_nonnegative])]) + if rel: + return rv + if rel is None: + raise ValueError('Invalid method for symbolic Range') + raise IndexError("Range index out of range") + + @property + def _inf(self): + if not self: + return S.EmptySet.inf + if self.has(Symbol): + if all(i.is_integer or i.is_infinite for i in self.args): + dif = self.stop - self.start + if self.step.is_positive and dif.is_positive: + return self.start + elif self.step.is_negative and dif.is_negative: + return self.stop - self.step + raise ValueError('invalid method for symbolic range') + if self.step > 0: + return self.start + else: + return self.stop - self.step + + @property + def _sup(self): + if not self: + return S.EmptySet.sup + if self.has(Symbol): + if all(i.is_integer or i.is_infinite for i in self.args): + dif = self.stop - self.start + if self.step.is_positive and dif.is_positive: + return self.stop - self.step + elif self.step.is_negative and dif.is_negative: + return self.start + raise ValueError('invalid method for symbolic range') + if self.step > 0: + return self.stop - self.step + else: + return self.start + + @property + def _boundary(self): + return self + + def as_relational(self, x): + """Rewrite a Range in terms of equalities and logic operators. """ + if self.start.is_infinite: + assert not self.stop.is_infinite # by instantiation + a = self.reversed.start + else: + a = self.start + step = self.step + in_seq = Eq(Mod(x - a, step), 0) + ints = And(Eq(Mod(a, 1), 0), Eq(Mod(step, 1), 0)) + n = (self.stop - self.start)/self.step + if n == 0: + return S.EmptySet.as_relational(x) + if n == 1: + return And(Eq(x, a), ints) + try: + a, b = self.inf, self.sup + except ValueError: + a = None + if a is not None: + range_cond = And( + x > a if a.is_infinite else x >= a, + x < b if b.is_infinite else x <= b) + else: + a, b = self.start, self.stop - self.step + range_cond = Or( + And(self.step >= 1, x > a if a.is_infinite else x >= a, + x < b if b.is_infinite else x <= b), + And(self.step <= -1, x < a if a.is_infinite else x <= a, + x > b if b.is_infinite else x >= b)) + return And(in_seq, ints, range_cond) + + +_sympy_converter[range] = lambda r: Range(r.start, r.stop, r.step) + +def normalize_theta_set(theta): + r""" + Normalize a Real Set `theta` in the interval `[0, 2\pi)`. It returns + a normalized value of theta in the Set. For Interval, a maximum of + one cycle $[0, 2\pi]$, is returned i.e. for theta equal to $[0, 10\pi]$, + returned normalized value would be $[0, 2\pi)$. As of now intervals + with end points as non-multiples of ``pi`` is not supported. + + Raises + ====== + + NotImplementedError + The algorithms for Normalizing theta Set are not yet + implemented. + ValueError + The input is not valid, i.e. the input is not a real set. + RuntimeError + It is a bug, please report to the github issue tracker. + + Examples + ======== + + >>> from sympy.sets.fancysets import normalize_theta_set + >>> from sympy import Interval, FiniteSet, pi + >>> normalize_theta_set(Interval(9*pi/2, 5*pi)) + Interval(pi/2, pi) + >>> normalize_theta_set(Interval(-3*pi/2, pi/2)) + Interval.Ropen(0, 2*pi) + >>> normalize_theta_set(Interval(-pi/2, pi/2)) + Union(Interval(0, pi/2), Interval.Ropen(3*pi/2, 2*pi)) + >>> normalize_theta_set(Interval(-4*pi, 3*pi)) + Interval.Ropen(0, 2*pi) + >>> normalize_theta_set(Interval(-3*pi/2, -pi/2)) + Interval(pi/2, 3*pi/2) + >>> normalize_theta_set(FiniteSet(0, pi, 3*pi)) + {0, pi} + + """ + from sympy.functions.elementary.trigonometric import _pi_coeff + + if theta.is_Interval: + interval_len = theta.measure + # one complete circle + if interval_len >= 2*S.Pi: + if interval_len == 2*S.Pi and theta.left_open and theta.right_open: + k = _pi_coeff(theta.start) + return Union(Interval(0, k*S.Pi, False, True), + Interval(k*S.Pi, 2*S.Pi, True, True)) + return Interval(0, 2*S.Pi, False, True) + + k_start, k_end = _pi_coeff(theta.start), _pi_coeff(theta.end) + + if k_start is None or k_end is None: + raise NotImplementedError("Normalizing theta without pi as coefficient is " + "not yet implemented") + new_start = k_start*S.Pi + new_end = k_end*S.Pi + + if new_start > new_end: + return Union(Interval(S.Zero, new_end, False, theta.right_open), + Interval(new_start, 2*S.Pi, theta.left_open, True)) + else: + return Interval(new_start, new_end, theta.left_open, theta.right_open) + + elif theta.is_FiniteSet: + new_theta = [] + for element in theta: + k = _pi_coeff(element) + if k is None: + raise NotImplementedError('Normalizing theta without pi as ' + 'coefficient, is not Implemented.') + else: + new_theta.append(k*S.Pi) + return FiniteSet(*new_theta) + + elif theta.is_Union: + return Union(*[normalize_theta_set(interval) for interval in theta.args]) + + elif theta.is_subset(S.Reals): + raise NotImplementedError("Normalizing theta when, it is of type %s is not " + "implemented" % type(theta)) + else: + raise ValueError(" %s is not a real set" % (theta)) + + +class ComplexRegion(Set): + r""" + Represents the Set of all Complex Numbers. It can represent a + region of Complex Plane in both the standard forms Polar and + Rectangular coordinates. + + * Polar Form + Input is in the form of the ProductSet or Union of ProductSets + of the intervals of ``r`` and ``theta``, and use the flag ``polar=True``. + + .. math:: Z = \{z \in \mathbb{C} \mid z = r\times (\cos(\theta) + I\sin(\theta)), r \in [\texttt{r}], \theta \in [\texttt{theta}]\} + + * Rectangular Form + Input is in the form of the ProductSet or Union of ProductSets + of interval of x and y, the real and imaginary parts of the Complex numbers in a plane. + Default input type is in rectangular form. + + .. math:: Z = \{z \in \mathbb{C} \mid z = x + Iy, x \in [\operatorname{re}(z)], y \in [\operatorname{im}(z)]\} + + Examples + ======== + + >>> from sympy import ComplexRegion, Interval, S, I, Union + >>> a = Interval(2, 3) + >>> b = Interval(4, 6) + >>> c1 = ComplexRegion(a*b) # Rectangular Form + >>> c1 + CartesianComplexRegion(ProductSet(Interval(2, 3), Interval(4, 6))) + + * c1 represents the rectangular region in complex plane + surrounded by the coordinates (2, 4), (3, 4), (3, 6) and + (2, 6), of the four vertices. + + >>> c = Interval(1, 8) + >>> c2 = ComplexRegion(Union(a*b, b*c)) + >>> c2 + CartesianComplexRegion(Union(ProductSet(Interval(2, 3), Interval(4, 6)), ProductSet(Interval(4, 6), Interval(1, 8)))) + + * c2 represents the Union of two rectangular regions in complex + plane. One of them surrounded by the coordinates of c1 and + other surrounded by the coordinates (4, 1), (6, 1), (6, 8) and + (4, 8). + + >>> 2.5 + 4.5*I in c1 + True + >>> 2.5 + 6.5*I in c1 + False + + >>> r = Interval(0, 1) + >>> theta = Interval(0, 2*S.Pi) + >>> c2 = ComplexRegion(r*theta, polar=True) # Polar Form + >>> c2 # unit Disk + PolarComplexRegion(ProductSet(Interval(0, 1), Interval.Ropen(0, 2*pi))) + + * c2 represents the region in complex plane inside the + Unit Disk centered at the origin. + + >>> 0.5 + 0.5*I in c2 + True + >>> 1 + 2*I in c2 + False + + >>> unit_disk = ComplexRegion(Interval(0, 1)*Interval(0, 2*S.Pi), polar=True) + >>> upper_half_unit_disk = ComplexRegion(Interval(0, 1)*Interval(0, S.Pi), polar=True) + >>> intersection = unit_disk.intersect(upper_half_unit_disk) + >>> intersection + PolarComplexRegion(ProductSet(Interval(0, 1), Interval(0, pi))) + >>> intersection == upper_half_unit_disk + True + + See Also + ======== + + CartesianComplexRegion + PolarComplexRegion + Complexes + + """ + is_ComplexRegion = True + + def __new__(cls, sets, polar=False): + if polar is False: + return CartesianComplexRegion(sets) + elif polar is True: + return PolarComplexRegion(sets) + else: + raise ValueError("polar should be either True or False") + + @property + def sets(self): + """ + Return raw input sets to the self. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion, Union + >>> a = Interval(2, 3) + >>> b = Interval(4, 5) + >>> c = Interval(1, 7) + >>> C1 = ComplexRegion(a*b) + >>> C1.sets + ProductSet(Interval(2, 3), Interval(4, 5)) + >>> C2 = ComplexRegion(Union(a*b, b*c)) + >>> C2.sets + Union(ProductSet(Interval(2, 3), Interval(4, 5)), ProductSet(Interval(4, 5), Interval(1, 7))) + + """ + return self.args[0] + + @property + def psets(self): + """ + Return a tuple of sets (ProductSets) input of the self. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion, Union + >>> a = Interval(2, 3) + >>> b = Interval(4, 5) + >>> c = Interval(1, 7) + >>> C1 = ComplexRegion(a*b) + >>> C1.psets + (ProductSet(Interval(2, 3), Interval(4, 5)),) + >>> C2 = ComplexRegion(Union(a*b, b*c)) + >>> C2.psets + (ProductSet(Interval(2, 3), Interval(4, 5)), ProductSet(Interval(4, 5), Interval(1, 7))) + + """ + if self.sets.is_ProductSet: + psets = () + psets = psets + (self.sets, ) + else: + psets = self.sets.args + return psets + + @property + def a_interval(self): + """ + Return the union of intervals of `x` when, self is in + rectangular form, or the union of intervals of `r` when + self is in polar form. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion, Union + >>> a = Interval(2, 3) + >>> b = Interval(4, 5) + >>> c = Interval(1, 7) + >>> C1 = ComplexRegion(a*b) + >>> C1.a_interval + Interval(2, 3) + >>> C2 = ComplexRegion(Union(a*b, b*c)) + >>> C2.a_interval + Union(Interval(2, 3), Interval(4, 5)) + + """ + a_interval = [] + for element in self.psets: + a_interval.append(element.args[0]) + + a_interval = Union(*a_interval) + return a_interval + + @property + def b_interval(self): + """ + Return the union of intervals of `y` when, self is in + rectangular form, or the union of intervals of `theta` + when self is in polar form. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion, Union + >>> a = Interval(2, 3) + >>> b = Interval(4, 5) + >>> c = Interval(1, 7) + >>> C1 = ComplexRegion(a*b) + >>> C1.b_interval + Interval(4, 5) + >>> C2 = ComplexRegion(Union(a*b, b*c)) + >>> C2.b_interval + Interval(1, 7) + + """ + b_interval = [] + for element in self.psets: + b_interval.append(element.args[1]) + + b_interval = Union(*b_interval) + return b_interval + + @property + def _measure(self): + """ + The measure of self.sets. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion, S + >>> a, b = Interval(2, 5), Interval(4, 8) + >>> c = Interval(0, 2*S.Pi) + >>> c1 = ComplexRegion(a*b) + >>> c1.measure + 12 + >>> c2 = ComplexRegion(a*c, polar=True) + >>> c2.measure + 6*pi + + """ + return self.sets._measure + + def _kind(self): + return self.args[0].kind + + @classmethod + def from_real(cls, sets): + """ + Converts given subset of real numbers to a complex region. + + Examples + ======== + + >>> from sympy import Interval, ComplexRegion + >>> unit = Interval(0,1) + >>> ComplexRegion.from_real(unit) + CartesianComplexRegion(ProductSet(Interval(0, 1), {0})) + + """ + if not sets.is_subset(S.Reals): + raise ValueError("sets must be a subset of the real line") + + return CartesianComplexRegion(sets * FiniteSet(0)) + + def _contains(self, other): + from sympy.functions import arg, Abs + + isTuple = isinstance(other, Tuple) + if isTuple and len(other) != 2: + raise ValueError('expecting Tuple of length 2') + + # If the other is not an Expression, and neither a Tuple + if not isinstance(other, (Expr, Tuple)): + return S.false + + # self in rectangular form + if not self.polar: + re, im = other if isTuple else other.as_real_imag() + return tfn[fuzzy_or(fuzzy_and([ + pset.args[0]._contains(re), + pset.args[1]._contains(im)]) + for pset in self.psets)] + + # self in polar form + elif self.polar: + if other.is_zero: + # ignore undefined complex argument + return tfn[fuzzy_or(pset.args[0]._contains(S.Zero) + for pset in self.psets)] + if isTuple: + r, theta = other + else: + r, theta = Abs(other), arg(other) + if theta.is_real and theta.is_number: + # angles in psets are normalized to [0, 2pi) + theta %= 2*S.Pi + return tfn[fuzzy_or(fuzzy_and([ + pset.args[0]._contains(r), + pset.args[1]._contains(theta)]) + for pset in self.psets)] + + +class CartesianComplexRegion(ComplexRegion): + r""" + Set representing a square region of the complex plane. + + .. math:: Z = \{z \in \mathbb{C} \mid z = x + Iy, x \in [\operatorname{re}(z)], y \in [\operatorname{im}(z)]\} + + Examples + ======== + + >>> from sympy import ComplexRegion, I, Interval + >>> region = ComplexRegion(Interval(1, 3) * Interval(4, 6)) + >>> 2 + 5*I in region + True + >>> 5*I in region + False + + See also + ======== + + ComplexRegion + PolarComplexRegion + Complexes + """ + + polar = False + variables = symbols('x, y', cls=Dummy) + + def __new__(cls, sets): + + if sets == S.Reals*S.Reals: + return S.Complexes + + if all(_a.is_FiniteSet for _a in sets.args) and (len(sets.args) == 2): + + # ** ProductSet of FiniteSets in the Complex Plane. ** + # For Cases like ComplexRegion({2, 4}*{3}), It + # would return {2 + 3*I, 4 + 3*I} + + # FIXME: This should probably be handled with something like: + # return ImageSet(Lambda((x, y), x+I*y), sets).rewrite(FiniteSet) + complex_num = [] + for x in sets.args[0]: + for y in sets.args[1]: + complex_num.append(x + S.ImaginaryUnit*y) + return FiniteSet(*complex_num) + else: + return Set.__new__(cls, sets) + + @property + def expr(self): + x, y = self.variables + return x + S.ImaginaryUnit*y + + +class PolarComplexRegion(ComplexRegion): + r""" + Set representing a polar region of the complex plane. + + .. math:: Z = \{z \in \mathbb{C} \mid z = r\times (\cos(\theta) + I\sin(\theta)), r \in [\texttt{r}], \theta \in [\texttt{theta}]\} + + Examples + ======== + + >>> from sympy import ComplexRegion, Interval, oo, pi, I + >>> rset = Interval(0, oo) + >>> thetaset = Interval(0, pi) + >>> upper_half_plane = ComplexRegion(rset * thetaset, polar=True) + >>> 1 + I in upper_half_plane + True + >>> 1 - I in upper_half_plane + False + + See also + ======== + + ComplexRegion + CartesianComplexRegion + Complexes + + """ + + polar = True + variables = symbols('r, theta', cls=Dummy) + + def __new__(cls, sets): + + new_sets = [] + # sets is Union of ProductSets + if not sets.is_ProductSet: + for k in sets.args: + new_sets.append(k) + # sets is ProductSets + else: + new_sets.append(sets) + # Normalize input theta + for k, v in enumerate(new_sets): + new_sets[k] = ProductSet(v.args[0], + normalize_theta_set(v.args[1])) + sets = Union(*new_sets) + return Set.__new__(cls, sets) + + @property + def expr(self): + r, theta = self.variables + return r*(cos(theta) + S.ImaginaryUnit*sin(theta)) + + +class Complexes(CartesianComplexRegion, metaclass=Singleton): + """ + The :class:`Set` of all complex numbers + + Examples + ======== + + >>> from sympy import S, I + >>> S.Complexes + Complexes + >>> 1 + I in S.Complexes + True + + See also + ======== + + Reals + ComplexRegion + + """ + + is_empty = False + is_finite_set = False + + # Override property from superclass since Complexes has no args + @property + def sets(self): + return ProductSet(S.Reals, S.Reals) + + def __new__(cls): + return Set.__new__(cls) diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/ordinals.py b/.venv/lib/python3.13/site-packages/sympy/sets/ordinals.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe062354cfe58a4747998e51fa0d261e67576cc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/ordinals.py @@ -0,0 +1,282 @@ +from sympy.core import Basic, Integer +import operator + + +class OmegaPower(Basic): + """ + Represents ordinal exponential and multiplication terms one of the + building blocks of the :class:`Ordinal` class. + In ``OmegaPower(a, b)``, ``a`` represents exponent and ``b`` represents multiplicity. + """ + def __new__(cls, a, b): + if isinstance(b, int): + b = Integer(b) + if not isinstance(b, Integer) or b <= 0: + raise TypeError("multiplicity must be a positive integer") + + if not isinstance(a, Ordinal): + a = Ordinal.convert(a) + + return Basic.__new__(cls, a, b) + + @property + def exp(self): + return self.args[0] + + @property + def mult(self): + return self.args[1] + + def _compare_term(self, other, op): + if self.exp == other.exp: + return op(self.mult, other.mult) + else: + return op(self.exp, other.exp) + + def __eq__(self, other): + if not isinstance(other, OmegaPower): + try: + other = OmegaPower(0, other) + except TypeError: + return NotImplemented + return self.args == other.args + + def __hash__(self): + return Basic.__hash__(self) + + def __lt__(self, other): + if not isinstance(other, OmegaPower): + try: + other = OmegaPower(0, other) + except TypeError: + return NotImplemented + return self._compare_term(other, operator.lt) + + +class Ordinal(Basic): + """ + Represents ordinals in Cantor normal form. + + Internally, this class is just a list of instances of OmegaPower. + + Examples + ======== + >>> from sympy import Ordinal, OmegaPower + >>> from sympy.sets.ordinals import omega + >>> w = omega + >>> w.is_limit_ordinal + True + >>> Ordinal(OmegaPower(w + 1, 1), OmegaPower(3, 2)) + w**(w + 1) + w**3*2 + >>> 3 + w + w + >>> (w + 1) * w + w**2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Ordinal_arithmetic + """ + def __new__(cls, *terms): + obj = super().__new__(cls, *terms) + powers = [i.exp for i in obj.args] + if not all(powers[i] >= powers[i+1] for i in range(len(powers) - 1)): + raise ValueError("powers must be in decreasing order") + return obj + + @property + def terms(self): + return self.args + + @property + def leading_term(self): + if self == ord0: + raise ValueError("ordinal zero has no leading term") + return self.terms[0] + + @property + def trailing_term(self): + if self == ord0: + raise ValueError("ordinal zero has no trailing term") + return self.terms[-1] + + @property + def is_successor_ordinal(self): + try: + return self.trailing_term.exp == ord0 + except ValueError: + return False + + @property + def is_limit_ordinal(self): + try: + return not self.trailing_term.exp == ord0 + except ValueError: + return False + + @property + def degree(self): + return self.leading_term.exp + + @classmethod + def convert(cls, integer_value): + if integer_value == 0: + return ord0 + return Ordinal(OmegaPower(0, integer_value)) + + def __eq__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + return self.terms == other.terms + + def __hash__(self): + return hash(self.args) + + def __lt__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + for term_self, term_other in zip(self.terms, other.terms): + if term_self != term_other: + return term_self < term_other + return len(self.terms) < len(other.terms) + + def __le__(self, other): + return (self == other or self < other) + + def __gt__(self, other): + return not self <= other + + def __ge__(self, other): + return not self < other + + def __str__(self): + net_str = "" + plus_count = 0 + if self == ord0: + return 'ord0' + for i in self.terms: + if plus_count: + net_str += " + " + + if i.exp == ord0: + net_str += str(i.mult) + elif i.exp == 1: + net_str += 'w' + elif len(i.exp.terms) > 1 or i.exp.is_limit_ordinal: + net_str += 'w**(%s)'%i.exp + else: + net_str += 'w**%s'%i.exp + + if not i.mult == 1 and not i.exp == ord0: + net_str += '*%s'%i.mult + + plus_count += 1 + return(net_str) + + __repr__ = __str__ + + def __add__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + if other == ord0: + return self + a_terms = list(self.terms) + b_terms = list(other.terms) + r = len(a_terms) - 1 + b_exp = other.degree + while r >= 0 and a_terms[r].exp < b_exp: + r -= 1 + if r < 0: + terms = b_terms + elif a_terms[r].exp == b_exp: + sum_term = OmegaPower(b_exp, a_terms[r].mult + other.leading_term.mult) + terms = a_terms[:r] + [sum_term] + b_terms[1:] + else: + terms = a_terms[:r+1] + b_terms + return Ordinal(*terms) + + def __radd__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + return other + self + + def __mul__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + if ord0 in (self, other): + return ord0 + a_exp = self.degree + a_mult = self.leading_term.mult + summation = [] + if other.is_limit_ordinal: + for arg in other.terms: + summation.append(OmegaPower(a_exp + arg.exp, arg.mult)) + + else: + for arg in other.terms[:-1]: + summation.append(OmegaPower(a_exp + arg.exp, arg.mult)) + b_mult = other.trailing_term.mult + summation.append(OmegaPower(a_exp, a_mult*b_mult)) + summation += list(self.terms[1:]) + return Ordinal(*summation) + + def __rmul__(self, other): + if not isinstance(other, Ordinal): + try: + other = Ordinal.convert(other) + except TypeError: + return NotImplemented + return other * self + + def __pow__(self, other): + if not self == omega: + return NotImplemented + return Ordinal(OmegaPower(other, 1)) + + +class OrdinalZero(Ordinal): + """The ordinal zero. + + OrdinalZero can be imported as ``ord0``. + """ + pass + + +class OrdinalOmega(Ordinal): + """The ordinal omega which forms the base of all ordinals in cantor normal form. + + OrdinalOmega can be imported as ``omega``. + + Examples + ======== + + >>> from sympy.sets.ordinals import omega + >>> omega + omega + w*2 + """ + def __new__(cls): + return Ordinal.__new__(cls) + + @property + def terms(self): + return (OmegaPower(1, 1),) + + +ord0 = OrdinalZero() +omega = OrdinalOmega() diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/powerset.py b/.venv/lib/python3.13/site-packages/sympy/sets/powerset.py new file mode 100644 index 0000000000000000000000000000000000000000..2eb3b41b9859281480bc9517a1cad0abe7a5683f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/powerset.py @@ -0,0 +1,119 @@ +from sympy.core.decorators import _sympifyit +from sympy.core.parameters import global_parameters +from sympy.core.logic import fuzzy_bool +from sympy.core.singleton import S +from sympy.core.sympify import _sympify + +from .sets import Set, FiniteSet, SetKind + + +class PowerSet(Set): + r"""A symbolic object representing a power set. + + Parameters + ========== + + arg : Set + The set to take power of. + + evaluate : bool + The flag to control evaluation. + + If the evaluation is disabled for finite sets, it can take + advantage of using subset test as a membership test. + + Notes + ===== + + Power set `\mathcal{P}(S)` is defined as a set containing all the + subsets of `S`. + + If the set `S` is a finite set, its power set would have + `2^{\left| S \right|}` elements, where `\left| S \right|` denotes + the cardinality of `S`. + + Examples + ======== + + >>> from sympy import PowerSet, S, FiniteSet + + A power set of a finite set: + + >>> PowerSet(FiniteSet(1, 2, 3)) + PowerSet({1, 2, 3}) + + A power set of an empty set: + + >>> PowerSet(S.EmptySet) + PowerSet(EmptySet) + >>> PowerSet(PowerSet(S.EmptySet)) + PowerSet(PowerSet(EmptySet)) + + A power set of an infinite set: + + >>> PowerSet(S.Reals) + PowerSet(Reals) + + Evaluating the power set of a finite set to its explicit form: + + >>> PowerSet(FiniteSet(1, 2, 3)).rewrite(FiniteSet) + FiniteSet(EmptySet, {1}, {2}, {3}, {1, 2}, {1, 3}, {2, 3}, {1, 2, 3}) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Power_set + + .. [2] https://en.wikipedia.org/wiki/Axiom_of_power_set + """ + def __new__(cls, arg, evaluate=None): + if evaluate is None: + evaluate=global_parameters.evaluate + + arg = _sympify(arg) + + if not isinstance(arg, Set): + raise ValueError('{} must be a set.'.format(arg)) + + return super().__new__(cls, arg) + + @property + def arg(self): + return self.args[0] + + def _eval_rewrite_as_FiniteSet(self, *args, **kwargs): + arg = self.arg + if arg.is_FiniteSet: + return arg.powerset() + return None + + @_sympifyit('other', NotImplemented) + def _contains(self, other): + if not isinstance(other, Set): + return None + + return fuzzy_bool(self.arg.is_superset(other)) + + def _eval_is_subset(self, other): + if isinstance(other, PowerSet): + return self.arg.is_subset(other.arg) + + def __len__(self): + return 2 ** len(self.arg) + + def __iter__(self): + found = [S.EmptySet] + yield S.EmptySet + + for x in self.arg: + temp = [] + x = FiniteSet(x) + for y in found: + new = x + y + yield new + temp.append(new) + found.extend(temp) + + @property + def kind(self): + return SetKind(self.arg.kind) diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/setexpr.py b/.venv/lib/python3.13/site-packages/sympy/sets/setexpr.py new file mode 100644 index 0000000000000000000000000000000000000000..94d77d5293617a620b70a945888987ce6cc61157 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/setexpr.py @@ -0,0 +1,97 @@ +from sympy.core import Expr +from sympy.core.decorators import call_highest_priority, _sympifyit +from .fancysets import ImageSet +from .sets import set_add, set_sub, set_mul, set_div, set_pow, set_function + + +class SetExpr(Expr): + """An expression that can take on values of a set. + + Examples + ======== + + >>> from sympy import Interval, FiniteSet + >>> from sympy.sets.setexpr import SetExpr + + >>> a = SetExpr(Interval(0, 5)) + >>> b = SetExpr(FiniteSet(1, 10)) + >>> (a + b).set + Union(Interval(1, 6), Interval(10, 15)) + >>> (2*a + b).set + Interval(1, 20) + """ + _op_priority = 11.0 + + def __new__(cls, setarg): + return Expr.__new__(cls, setarg) + + set = property(lambda self: self.args[0]) + + def _latex(self, printer): + return r"SetExpr\left({}\right)".format(printer._print(self.set)) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__radd__') + def __add__(self, other): + return _setexpr_apply_operation(set_add, self, other) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__add__') + def __radd__(self, other): + return _setexpr_apply_operation(set_add, other, self) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rmul__') + def __mul__(self, other): + return _setexpr_apply_operation(set_mul, self, other) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__mul__') + def __rmul__(self, other): + return _setexpr_apply_operation(set_mul, other, self) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rsub__') + def __sub__(self, other): + return _setexpr_apply_operation(set_sub, self, other) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__sub__') + def __rsub__(self, other): + return _setexpr_apply_operation(set_sub, other, self) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rpow__') + def __pow__(self, other): + return _setexpr_apply_operation(set_pow, self, other) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__pow__') + def __rpow__(self, other): + return _setexpr_apply_operation(set_pow, other, self) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rtruediv__') + def __truediv__(self, other): + return _setexpr_apply_operation(set_div, self, other) + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__truediv__') + def __rtruediv__(self, other): + return _setexpr_apply_operation(set_div, other, self) + + def _eval_func(self, func): + # TODO: this could be implemented straight into `imageset`: + res = set_function(func, self.set) + if res is None: + return SetExpr(ImageSet(func, self.set)) + return SetExpr(res) + + +def _setexpr_apply_operation(op, x, y): + if isinstance(x, SetExpr): + x = x.set + if isinstance(y, SetExpr): + y = y.set + out = op(x, y) + return SetExpr(out) diff --git a/.venv/lib/python3.13/site-packages/sympy/sets/sets.py b/.venv/lib/python3.13/site-packages/sympy/sets/sets.py new file mode 100644 index 0000000000000000000000000000000000000000..3c85ce87c515cfd4520dcc6b9265fe76d8c6163f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/sets/sets.py @@ -0,0 +1,2804 @@ +from __future__ import annotations + +from typing import Any, Callable, TYPE_CHECKING, overload +from functools import reduce +from collections import defaultdict +from collections.abc import Mapping, Iterable +import inspect + +from sympy.core.kind import Kind, UndefinedKind, NumberKind +from sympy.core.basic import Basic +from sympy.core.containers import Tuple, TupleKind +from sympy.core.decorators import sympify_method_args, sympify_return +from sympy.core.evalf import EvalfMixin +from sympy.core.expr import Expr +from sympy.core.function import Lambda +from sympy.core.logic import (FuzzyBool, fuzzy_bool, fuzzy_or, fuzzy_and, + fuzzy_not) +from sympy.core.numbers import Float, Integer +from sympy.core.operations import LatticeOp +from sympy.core.parameters import global_parameters +from sympy.core.relational import Eq, Ne, is_lt +from sympy.core.singleton import Singleton, S +from sympy.core.sorting import ordered +from sympy.core.symbol import symbols, Symbol, Dummy, uniquely_named_symbol +from sympy.core.sympify import _sympify, sympify, _sympy_converter +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.miscellaneous import Max, Min +from sympy.logic.boolalg import And, Or, Not, Xor, true, false +from sympy.utilities.decorator import deprecated +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import (iproduct, sift, roundrobin, iterable, + subsets) +from sympy.utilities.misc import func_name, filldedent + +from mpmath import mpi, mpf + +from mpmath.libmp.libmpf import prec_to_dps + + +tfn = defaultdict(lambda: None, { + True: S.true, + S.true: S.true, + False: S.false, + S.false: S.false}) + + +@sympify_method_args +class Set(Basic, EvalfMixin): + """ + The base class for any kind of set. + + Explanation + =========== + + This is not meant to be used directly as a container of items. It does not + behave like the builtin ``set``; see :class:`FiniteSet` for that. + + Real intervals are represented by the :class:`Interval` class and unions of + sets by the :class:`Union` class. The empty set is represented by the + :class:`EmptySet` class and available as a singleton as ``S.EmptySet``. + """ + + __slots__: tuple[()] = () + + is_number = False + is_iterable = False + is_interval = False + + is_FiniteSet = False + is_Interval = False + is_ProductSet = False + is_Union = False + is_Intersection: FuzzyBool = None + is_UniversalSet: FuzzyBool = None + is_Complement: FuzzyBool = None + is_ComplexRegion = False + + is_empty: FuzzyBool = None + is_finite_set: FuzzyBool = None + + @property # type: ignore + @deprecated( + """ + The is_EmptySet attribute of Set objects is deprecated. + Use 's is S.EmptySet" or 's.is_empty' instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-is-emptyset", + ) + def is_EmptySet(self): + return None + + if TYPE_CHECKING: + + def __new__(cls, *args: Basic | complex) -> Set: + ... + + @overload # type: ignore + def subs(self, arg1: Mapping[Basic | complex, Set | complex], arg2: None=None) -> Set: ... + @overload + def subs(self, arg1: Iterable[tuple[Basic | complex, Set | complex]], arg2: None=None, **kwargs: Any) -> Set: ... + @overload + def subs(self, arg1: Set | complex, arg2: Set | complex) -> Set: ... + @overload + def subs(self, arg1: Mapping[Basic | complex, Basic | complex], arg2: None=None, **kwargs: Any) -> Basic: ... + @overload + def subs(self, arg1: Iterable[tuple[Basic | complex, Basic | complex]], arg2: None=None, **kwargs: Any) -> Basic: ... + @overload + def subs(self, arg1: Basic | complex, arg2: Basic | complex, **kwargs: Any) -> Basic: ... + + def subs(self, arg1: Mapping[Basic | complex, Basic | complex] | Basic | complex, # type: ignore + arg2: Basic | complex | None = None, **kwargs: Any) -> Basic: + ... + + def simplify(self, **kwargs) -> Set: + assert False + + def evalf(self, n: int = 15, subs: dict[Basic, Basic | float] | None = None, + maxn: int = 100, chop: bool = False, strict: bool = False, + quad: str | None = None, verbose: bool = False) -> Set: + ... + + n = evalf + + @staticmethod + def _infimum_key(expr): + """ + Return infimum (if possible) else S.Infinity. + """ + try: + infimum = expr.inf + assert infimum.is_comparable + infimum = infimum.evalf() # issue #18505 + except (NotImplementedError, + AttributeError, AssertionError, ValueError): + infimum = S.Infinity + return infimum + + def union(self, other): + """ + Returns the union of ``self`` and ``other``. + + Examples + ======== + + As a shortcut it is possible to use the ``+`` operator: + + >>> from sympy import Interval, FiniteSet + >>> Interval(0, 1).union(Interval(2, 3)) + Union(Interval(0, 1), Interval(2, 3)) + >>> Interval(0, 1) + Interval(2, 3) + Union(Interval(0, 1), Interval(2, 3)) + >>> Interval(1, 2, True, True) + FiniteSet(2, 3) + Union({3}, Interval.Lopen(1, 2)) + + Similarly it is possible to use the ``-`` operator for set differences: + + >>> Interval(0, 2) - Interval(0, 1) + Interval.Lopen(1, 2) + >>> Interval(1, 3) - FiniteSet(2) + Union(Interval.Ropen(1, 2), Interval.Lopen(2, 3)) + + """ + return Union(self, other) + + def intersect(self, other): + """ + Returns the intersection of 'self' and 'other'. + + Examples + ======== + + >>> from sympy import Interval + + >>> Interval(1, 3).intersect(Interval(1, 2)) + Interval(1, 2) + + >>> from sympy import imageset, Lambda, symbols, S + >>> n, m = symbols('n m') + >>> a = imageset(Lambda(n, 2*n), S.Integers) + >>> a.intersect(imageset(Lambda(m, 2*m + 1), S.Integers)) + EmptySet + + """ + return Intersection(self, other) + + def intersection(self, other): + """ + Alias for :meth:`intersect()` + """ + return self.intersect(other) + + def is_disjoint(self, other): + """ + Returns True if ``self`` and ``other`` are disjoint. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 2).is_disjoint(Interval(1, 2)) + False + >>> Interval(0, 2).is_disjoint(Interval(3, 4)) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Disjoint_sets + """ + return self.intersect(other) == S.EmptySet + + def isdisjoint(self, other): + """ + Alias for :meth:`is_disjoint()` + """ + return self.is_disjoint(other) + + def complement(self, universe): + r""" + The complement of 'self' w.r.t the given universe. + + Examples + ======== + + >>> from sympy import Interval, S + >>> Interval(0, 1).complement(S.Reals) + Union(Interval.open(-oo, 0), Interval.open(1, oo)) + + >>> Interval(0, 1).complement(S.UniversalSet) + Complement(UniversalSet, Interval(0, 1)) + + """ + return Complement(universe, self) + + def _complement(self, other): + # this behaves as other - self + if isinstance(self, ProductSet) and isinstance(other, ProductSet): + # If self and other are disjoint then other - self == self + if len(self.sets) != len(other.sets): + return other + + # There can be other ways to represent this but this gives: + # (A x B) - (C x D) = ((A - C) x B) U (A x (B - D)) + overlaps = [] + pairs = list(zip(self.sets, other.sets)) + for n in range(len(pairs)): + sets = (o if i != n else o-s for i, (s, o) in enumerate(pairs)) + overlaps.append(ProductSet(*sets)) + return Union(*overlaps) + + elif isinstance(other, Interval): + if isinstance(self, (Interval, FiniteSet)): + return Intersection(other, self.complement(S.Reals)) + + elif isinstance(other, Union): + return Union(*(o - self for o in other.args)) + + elif isinstance(other, Complement): + return Complement(other.args[0], Union(other.args[1], self), evaluate=False) + + elif other is S.EmptySet: + return S.EmptySet + + elif isinstance(other, FiniteSet): + sifted = sift(other, lambda x: fuzzy_bool(self.contains(x))) + # ignore those that are contained in self + return Union(FiniteSet(*(sifted[False])), + Complement(FiniteSet(*(sifted[None])), self, evaluate=False) + if sifted[None] else S.EmptySet) + + def symmetric_difference(self, other): + """ + Returns symmetric difference of ``self`` and ``other``. + + Examples + ======== + + >>> from sympy import Interval, S + >>> Interval(1, 3).symmetric_difference(S.Reals) + Union(Interval.open(-oo, 1), Interval.open(3, oo)) + >>> Interval(1, 10).symmetric_difference(S.Reals) + Union(Interval.open(-oo, 1), Interval.open(10, oo)) + + >>> from sympy import S, EmptySet + >>> S.Reals.symmetric_difference(EmptySet) + Reals + + References + ========== + .. [1] https://en.wikipedia.org/wiki/Symmetric_difference + + """ + return SymmetricDifference(self, other) + + def _symmetric_difference(self, other): + return Union(Complement(self, other), Complement(other, self)) + + @property + def inf(self): + """ + The infimum of ``self``. + + Examples + ======== + + >>> from sympy import Interval, Union + >>> Interval(0, 1).inf + 0 + >>> Union(Interval(0, 1), Interval(2, 3)).inf + 0 + + """ + return self._inf + + @property + def _inf(self): + raise NotImplementedError("(%s)._inf" % self) + + @property + def sup(self): + """ + The supremum of ``self``. + + Examples + ======== + + >>> from sympy import Interval, Union + >>> Interval(0, 1).sup + 1 + >>> Union(Interval(0, 1), Interval(2, 3)).sup + 3 + + """ + return self._sup + + @property + def _sup(self): + raise NotImplementedError("(%s)._sup" % self) + + def contains(self, other): + """ + Returns a SymPy value indicating whether ``other`` is contained + in ``self``: ``true`` if it is, ``false`` if it is not, else + an unevaluated ``Contains`` expression (or, as in the case of + ConditionSet and a union of FiniteSet/Intervals, an expression + indicating the conditions for containment). + + Examples + ======== + + >>> from sympy import Interval, S + >>> from sympy.abc import x + + >>> Interval(0, 1).contains(0.5) + True + + As a shortcut it is possible to use the ``in`` operator, but that + will raise an error unless an affirmative true or false is not + obtained. + + >>> Interval(0, 1).contains(x) + (0 <= x) & (x <= 1) + >>> x in Interval(0, 1) + Traceback (most recent call last): + ... + TypeError: did not evaluate to a bool: None + + The result of 'in' is a bool, not a SymPy value + + >>> 1 in Interval(0, 2) + True + >>> _ is S.true + False + """ + from .contains import Contains + other = sympify(other, strict=True) + + c = self._contains(other) + if isinstance(c, Contains): + return c + if c is None: + return Contains(other, self, evaluate=False) + b = tfn[c] + if b is None: + return c + return b + + def _contains(self, other): + """Test if ``other`` is an element of the set ``self``. + + This is an internal method that is expected to be overridden by + subclasses of ``Set`` and will be called by the public + :func:`Set.contains` method or the :class:`Contains` expression. + + Parameters + ========== + + other: Sympified :class:`Basic` instance + The object whose membership in ``self`` is to be tested. + + Returns + ======= + + Symbolic :class:`Boolean` or ``None``. + + A return value of ``None`` indicates that it is unknown whether + ``other`` is contained in ``self``. Returning ``None`` from here + ensures that ``self.contains(other)`` or ``Contains(self, other)`` will + return an unevaluated :class:`Contains` expression. + + If not ``None`` then the returned value is a :class:`Boolean` that is + logically equivalent to the statement that ``other`` is an element of + ``self``. Usually this would be either ``S.true`` or ``S.false`` but + not always. + """ + raise NotImplementedError(f"{type(self).__name__}._contains") + + def is_subset(self, other): + """ + Returns True if ``self`` is a subset of ``other``. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 0.5).is_subset(Interval(0, 1)) + True + >>> Interval(0, 1).is_subset(Interval(0, 1, left_open=True)) + False + + """ + if not isinstance(other, Set): + raise ValueError("Unknown argument '%s'" % other) + + # Handle the trivial cases + if self == other: + return True + is_empty = self.is_empty + if is_empty is True: + return True + elif fuzzy_not(is_empty) and other.is_empty: + return False + if self.is_finite_set is False and other.is_finite_set: + return False + + # Dispatch on subclass rules + ret = self._eval_is_subset(other) + if ret is not None: + return ret + ret = other._eval_is_superset(self) + if ret is not None: + return ret + + # Use pairwise rules from multiple dispatch + from sympy.sets.handlers.issubset import is_subset_sets + ret = is_subset_sets(self, other) + if ret is not None: + return ret + + # Fall back on computing the intersection + # XXX: We shouldn't do this. A query like this should be handled + # without evaluating new Set objects. It should be the other way round + # so that the intersect method uses is_subset for evaluation. + if self.intersect(other) == self: + return True + + def _eval_is_subset(self, other): + '''Returns a fuzzy bool for whether self is a subset of other.''' + return None + + def _eval_is_superset(self, other): + '''Returns a fuzzy bool for whether self is a subset of other.''' + return None + + # This should be deprecated: + def issubset(self, other): + """ + Alias for :meth:`is_subset()` + """ + return self.is_subset(other) + + def is_proper_subset(self, other): + """ + Returns True if ``self`` is a proper subset of ``other``. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 0.5).is_proper_subset(Interval(0, 1)) + True + >>> Interval(0, 1).is_proper_subset(Interval(0, 1)) + False + + """ + if isinstance(other, Set): + return self != other and self.is_subset(other) + else: + raise ValueError("Unknown argument '%s'" % other) + + def is_superset(self, other): + """ + Returns True if ``self`` is a superset of ``other``. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 0.5).is_superset(Interval(0, 1)) + False + >>> Interval(0, 1).is_superset(Interval(0, 1, left_open=True)) + True + + """ + if isinstance(other, Set): + return other.is_subset(self) + else: + raise ValueError("Unknown argument '%s'" % other) + + # This should be deprecated: + def issuperset(self, other): + """ + Alias for :meth:`is_superset()` + """ + return self.is_superset(other) + + def is_proper_superset(self, other): + """ + Returns True if ``self`` is a proper superset of ``other``. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1).is_proper_superset(Interval(0, 0.5)) + True + >>> Interval(0, 1).is_proper_superset(Interval(0, 1)) + False + + """ + if isinstance(other, Set): + return self != other and self.is_superset(other) + else: + raise ValueError("Unknown argument '%s'" % other) + + def _eval_powerset(self): + from .powerset import PowerSet + return PowerSet(self) + + def powerset(self): + """ + Find the Power set of ``self``. + + Examples + ======== + + >>> from sympy import EmptySet, FiniteSet, Interval + + A power set of an empty set: + + >>> A = EmptySet + >>> A.powerset() + {EmptySet} + + A power set of a finite set: + + >>> A = FiniteSet(1, 2) + >>> a, b, c = FiniteSet(1), FiniteSet(2), FiniteSet(1, 2) + >>> A.powerset() == FiniteSet(a, b, c, EmptySet) + True + + A power set of an interval: + + >>> Interval(1, 2).powerset() + PowerSet(Interval(1, 2)) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Power_set + + """ + return self._eval_powerset() + + @property + def measure(self): + """ + The (Lebesgue) measure of ``self``. + + Examples + ======== + + >>> from sympy import Interval, Union + >>> Interval(0, 1).measure + 1 + >>> Union(Interval(0, 1), Interval(2, 3)).measure + 2 + + """ + return self._measure + + @property + def kind(self): + """ + The kind of a Set + + Explanation + =========== + + Any :class:`Set` will have kind :class:`SetKind` which is + parametrised by the kind of the elements of the set. For example + most sets are sets of numbers and will have kind + ``SetKind(NumberKind)``. If elements of sets are different in kind than + their kind will ``SetKind(UndefinedKind)``. See + :class:`sympy.core.kind.Kind` for an explanation of the kind system. + + Examples + ======== + + >>> from sympy import Interval, Matrix, FiniteSet, EmptySet, ProductSet, PowerSet + + >>> FiniteSet(Matrix([1, 2])).kind + SetKind(MatrixKind(NumberKind)) + + >>> Interval(1, 2).kind + SetKind(NumberKind) + + >>> EmptySet.kind + SetKind() + + A :class:`sympy.sets.powerset.PowerSet` is a set of sets: + + >>> PowerSet({1, 2, 3}).kind + SetKind(SetKind(NumberKind)) + + A :class:`ProductSet` represents the set of tuples of elements of + other sets. Its kind is :class:`sympy.core.containers.TupleKind` + parametrised by the kinds of the elements of those sets: + + >>> p = ProductSet(FiniteSet(1, 2), FiniteSet(3, 4)) + >>> list(p) + [(1, 3), (2, 3), (1, 4), (2, 4)] + >>> p.kind + SetKind(TupleKind(NumberKind, NumberKind)) + + When all elements of the set do not have same kind, the kind + will be returned as ``SetKind(UndefinedKind)``: + + >>> FiniteSet(0, Matrix([1, 2])).kind + SetKind(UndefinedKind) + + The kind of the elements of a set are given by the ``element_kind`` + attribute of ``SetKind``: + + >>> Interval(1, 2).kind.element_kind + NumberKind + + See Also + ======== + + NumberKind + sympy.core.kind.UndefinedKind + sympy.core.containers.TupleKind + MatrixKind + sympy.matrices.expressions.sets.MatrixSet + sympy.sets.conditionset.ConditionSet + Rationals + Naturals + Integers + sympy.sets.fancysets.ImageSet + sympy.sets.fancysets.Range + sympy.sets.fancysets.ComplexRegion + sympy.sets.powerset.PowerSet + sympy.sets.sets.ProductSet + sympy.sets.sets.Interval + sympy.sets.sets.Union + sympy.sets.sets.Intersection + sympy.sets.sets.Complement + sympy.sets.sets.EmptySet + sympy.sets.sets.UniversalSet + sympy.sets.sets.FiniteSet + sympy.sets.sets.SymmetricDifference + sympy.sets.sets.DisjointUnion + """ + return self._kind() + + @property + def boundary(self): + """ + The boundary or frontier of a set. + + Explanation + =========== + + A point x is on the boundary of a set S if + + 1. x is in the closure of S. + I.e. Every neighborhood of x contains a point in S. + 2. x is not in the interior of S. + I.e. There does not exist an open set centered on x contained + entirely within S. + + There are the points on the outer rim of S. If S is open then these + points need not actually be contained within S. + + For example, the boundary of an interval is its start and end points. + This is true regardless of whether or not the interval is open. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1).boundary + {0, 1} + >>> Interval(0, 1, True, False).boundary + {0, 1} + """ + return self._boundary + + @property + def is_open(self): + """ + Property method to check whether a set is open. + + Explanation + =========== + + A set is open if and only if it has an empty intersection with its + boundary. In particular, a subset A of the reals is open if and only + if each one of its points is contained in an open interval that is a + subset of A. + + Examples + ======== + >>> from sympy import S + >>> S.Reals.is_open + True + >>> S.Rationals.is_open + False + """ + return Intersection(self, self.boundary).is_empty + + @property + def is_closed(self): + """ + A property method to check whether a set is closed. + + Explanation + =========== + + A set is closed if its complement is an open set. The closedness of a + subset of the reals is determined with respect to R and its standard + topology. + + Examples + ======== + >>> from sympy import Interval + >>> Interval(0, 1).is_closed + True + """ + return self.boundary.is_subset(self) + + @property + def closure(self): + """ + Property method which returns the closure of a set. + The closure is defined as the union of the set itself and its + boundary. + + Examples + ======== + >>> from sympy import S, Interval + >>> S.Reals.closure + Reals + >>> Interval(0, 1).closure + Interval(0, 1) + """ + return self + self.boundary + + @property + def interior(self): + """ + Property method which returns the interior of a set. + The interior of a set S consists all points of S that do not + belong to the boundary of S. + + Examples + ======== + >>> from sympy import Interval + >>> Interval(0, 1).interior + Interval.open(0, 1) + >>> Interval(0, 1).boundary.interior + EmptySet + """ + return self - self.boundary + + @property + def _boundary(self): + raise NotImplementedError() + + @property + def _measure(self): + raise NotImplementedError("(%s)._measure" % self) + + def _kind(self): + return SetKind(UndefinedKind) + + def _eval_evalf(self, prec): + dps = prec_to_dps(prec) + return self.func(*[arg.evalf(n=dps) for arg in self.args]) + + @sympify_return([('other', 'Set')], NotImplemented) + def __add__(self, other): + return self.union(other) + + @sympify_return([('other', 'Set')], NotImplemented) + def __or__(self, other): + return self.union(other) + + @sympify_return([('other', 'Set')], NotImplemented) + def __and__(self, other): + return self.intersect(other) + + @sympify_return([('other', 'Set')], NotImplemented) + def __mul__(self, other): + return ProductSet(self, other) + + @sympify_return([('other', 'Set')], NotImplemented) + def __xor__(self, other): + return SymmetricDifference(self, other) + + @sympify_return([('exp', Expr)], NotImplemented) + def __pow__(self, exp): + if not (exp.is_Integer and exp >= 0): + raise ValueError("%s: Exponent must be a positive Integer" % exp) + return ProductSet(*[self]*exp) + + @sympify_return([('other', 'Set')], NotImplemented) + def __sub__(self, other): + return Complement(self, other) + + def __contains__(self, other): + other = _sympify(other) + c = self._contains(other) + b = tfn[c] + if b is None: + # x in y must evaluate to T or F; to entertain a None + # result with Set use y.contains(x) + raise TypeError('did not evaluate to a bool: %r' % c) + return b + + +class ProductSet(Set): + """ + Represents a Cartesian Product of Sets. + + Explanation + =========== + + Returns a Cartesian product given several sets as either an iterable + or individual arguments. + + Can use ``*`` operator on any sets for convenient shorthand. + + Examples + ======== + + >>> from sympy import Interval, FiniteSet, ProductSet + >>> I = Interval(0, 5); S = FiniteSet(1, 2, 3) + >>> ProductSet(I, S) + ProductSet(Interval(0, 5), {1, 2, 3}) + + >>> (2, 2) in ProductSet(I, S) + True + + >>> Interval(0, 1) * Interval(0, 1) # The unit square + ProductSet(Interval(0, 1), Interval(0, 1)) + + >>> coin = FiniteSet('H', 'T') + >>> set(coin**2) + {(H, H), (H, T), (T, H), (T, T)} + + The Cartesian product is not commutative or associative e.g.: + + >>> I*S == S*I + False + >>> (I*I)*I == I*(I*I) + False + + Notes + ===== + + - Passes most operations down to the argument sets + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Cartesian_product + """ + is_ProductSet = True + + def __new__(cls, *sets, **assumptions): + if len(sets) == 1 and iterable(sets[0]) and not isinstance(sets[0], (Set, set)): + sympy_deprecation_warning( + """ +ProductSet(iterable) is deprecated. Use ProductSet(*iterable) instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-productset-iterable", + ) + sets = tuple(sets[0]) + + sets = [sympify(s) for s in sets] + + if not all(isinstance(s, Set) for s in sets): + raise TypeError("Arguments to ProductSet should be of type Set") + + # Nullary product of sets is *not* the empty set + if len(sets) == 0: + return FiniteSet(()) + + if S.EmptySet in sets: + return S.EmptySet + + return Basic.__new__(cls, *sets, **assumptions) + + @property + def sets(self): + return self.args + + def flatten(self): + def _flatten(sets): + for s in sets: + if s.is_ProductSet: + yield from _flatten(s.sets) + else: + yield s + return ProductSet(*_flatten(self.sets)) + + + + def _contains(self, element): + """ + ``in`` operator for ProductSets. + + Examples + ======== + + >>> from sympy import Interval + >>> (2, 3) in Interval(0, 5) * Interval(0, 5) + True + + >>> (10, 10) in Interval(0, 5) * Interval(0, 5) + False + + Passes operation on to constituent sets + """ + if element.is_Symbol: + return None + + if not isinstance(element, Tuple) or len(element) != len(self.sets): + return S.false + + return And(*[s.contains(e) for s, e in zip(self.sets, element)]) + + def as_relational(self, *symbols): + symbols = [_sympify(s) for s in symbols] + if len(symbols) != len(self.sets) or not all( + i.is_Symbol for i in symbols): + raise ValueError( + 'number of symbols must match the number of sets') + return And(*[s.as_relational(i) for s, i in zip(self.sets, symbols)]) + + @property + def _boundary(self): + return Union(*(ProductSet(*(b + b.boundary if i != j else b.boundary + for j, b in enumerate(self.sets))) + for i, a in enumerate(self.sets))) + + @property + def is_iterable(self): + """ + A property method which tests whether a set is iterable or not. + Returns True if set is iterable, otherwise returns False. + + Examples + ======== + + >>> from sympy import FiniteSet, Interval + >>> I = Interval(0, 1) + >>> A = FiniteSet(1, 2, 3, 4, 5) + >>> I.is_iterable + False + >>> A.is_iterable + True + + """ + return all(set.is_iterable for set in self.sets) + + def __iter__(self): + """ + A method which implements is_iterable property method. + If self.is_iterable returns True (both constituent sets are iterable), + then return the Cartesian Product. Otherwise, raise TypeError. + """ + return iproduct(*self.sets) + + @property + def is_empty(self): + return fuzzy_or(s.is_empty for s in self.sets) + + @property + def is_finite_set(self): + all_finite = fuzzy_and(s.is_finite_set for s in self.sets) + return fuzzy_or([self.is_empty, all_finite]) + + @property + def _measure(self): + measure = 1 + for s in self.sets: + measure *= s.measure + return measure + + def _kind(self): + return SetKind(TupleKind(*(i.kind.element_kind for i in self.args))) + + def __len__(self): + return reduce(lambda a, b: a*b, (len(s) for s in self.args)) + + def __bool__(self): + return all(self.sets) + + +class Interval(Set): + """ + Represents a real interval as a Set. + + Usage: + Returns an interval with end points ``start`` and ``end``. + + For ``left_open=True`` (default ``left_open`` is ``False``) the interval + will be open on the left. Similarly, for ``right_open=True`` the interval + will be open on the right. + + Examples + ======== + + >>> from sympy import Symbol, Interval + >>> Interval(0, 1) + Interval(0, 1) + >>> Interval.Ropen(0, 1) + Interval.Ropen(0, 1) + >>> Interval.Ropen(0, 1) + Interval.Ropen(0, 1) + >>> Interval.Lopen(0, 1) + Interval.Lopen(0, 1) + >>> Interval.open(0, 1) + Interval.open(0, 1) + + >>> a = Symbol('a', real=True) + >>> Interval(0, a) + Interval(0, a) + + Notes + ===== + - Only real end points are supported + - ``Interval(a, b)`` with $a > b$ will return the empty set + - Use the ``evalf()`` method to turn an Interval into an mpmath + ``mpi`` interval instance + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Interval_%28mathematics%29 + """ + is_Interval = True + + def __new__(cls, start, end, left_open=False, right_open=False): + + start = _sympify(start) + end = _sympify(end) + left_open = _sympify(left_open) + right_open = _sympify(right_open) + + if not all(isinstance(a, (type(true), type(false))) + for a in [left_open, right_open]): + raise NotImplementedError( + "left_open and right_open can have only true/false values, " + "got %s and %s" % (left_open, right_open)) + + # Only allow real intervals + if fuzzy_not(fuzzy_and(i.is_extended_real for i in (start, end, end-start))): + raise ValueError("Non-real intervals are not supported") + + # evaluate if possible + if is_lt(end, start): + return S.EmptySet + elif (end - start).is_negative: + return S.EmptySet + + if end == start and (left_open or right_open): + return S.EmptySet + if end == start and not (left_open or right_open): + if start is S.Infinity or start is S.NegativeInfinity: + return S.EmptySet + return FiniteSet(end) + + # Make sure infinite interval end points are open. + if start is S.NegativeInfinity: + left_open = true + if end is S.Infinity: + right_open = true + if start == S.Infinity or end == S.NegativeInfinity: + return S.EmptySet + + return Basic.__new__(cls, start, end, left_open, right_open) + + @property + def start(self): + """ + The left end point of the interval. + + This property takes the same value as the ``inf`` property. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1).start + 0 + + """ + return self._args[0] + + @property + def end(self): + """ + The right end point of the interval. + + This property takes the same value as the ``sup`` property. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1).end + 1 + + """ + return self._args[1] + + @property + def left_open(self): + """ + True if interval is left-open. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1, left_open=True).left_open + True + >>> Interval(0, 1, left_open=False).left_open + False + + """ + return self._args[2] + + @property + def right_open(self): + """ + True if interval is right-open. + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(0, 1, right_open=True).right_open + True + >>> Interval(0, 1, right_open=False).right_open + False + + """ + return self._args[3] + + @classmethod + def open(cls, a, b): + """Return an interval including neither boundary.""" + return cls(a, b, True, True) + + @classmethod + def Lopen(cls, a, b): + """Return an interval not including the left boundary.""" + return cls(a, b, True, False) + + @classmethod + def Ropen(cls, a, b): + """Return an interval not including the right boundary.""" + return cls(a, b, False, True) + + @property + def _inf(self): + return self.start + + @property + def _sup(self): + return self.end + + @property + def left(self): + return self.start + + @property + def right(self): + return self.end + + @property + def is_empty(self): + if self.left_open or self.right_open: + cond = self.start >= self.end # One/both bounds open + else: + cond = self.start > self.end # Both bounds closed + return fuzzy_bool(cond) + + @property + def is_finite_set(self): + return self.measure.is_zero + + def _complement(self, other): + if other == S.Reals: + a = Interval(S.NegativeInfinity, self.start, + True, not self.left_open) + b = Interval(self.end, S.Infinity, not self.right_open, True) + return Union(a, b) + + if isinstance(other, FiniteSet): + nums = [m for m in other.args if m.is_number] + if nums == []: + return None + + return Set._complement(self, other) + + @property + def _boundary(self): + finite_points = [p for p in (self.start, self.end) + if abs(p) != S.Infinity] + return FiniteSet(*finite_points) + + def _contains(self, other): + if (not isinstance(other, Expr) or other is S.NaN + or other.is_real is False or other.has(S.ComplexInfinity)): + # if an expression has zoo it will be zoo or nan + # and neither of those is real + return false + + if self.start is S.NegativeInfinity and self.end is S.Infinity: + if other.is_real is not None: + return tfn[other.is_real] + + d = Dummy() + return self.as_relational(d).subs(d, other) + + def as_relational(self, x): + """Rewrite an interval in terms of inequalities and logic operators.""" + x = sympify(x) + if self.right_open: + right = x < self.end + else: + right = x <= self.end + if self.left_open: + left = self.start < x + else: + left = self.start <= x + return And(left, right) + + @property + def _measure(self): + return self.end - self.start + + def _kind(self): + return SetKind(NumberKind) + + def to_mpi(self, prec=53): + return mpi(mpf(self.start._eval_evalf(prec)), + mpf(self.end._eval_evalf(prec))) + + def _eval_evalf(self, prec): + return Interval(self.left._evalf(prec), self.right._evalf(prec), + left_open=self.left_open, right_open=self.right_open) + + def _is_comparable(self, other): + is_comparable = self.start.is_comparable + is_comparable &= self.end.is_comparable + is_comparable &= other.start.is_comparable + is_comparable &= other.end.is_comparable + + return is_comparable + + @property + def is_left_unbounded(self): + """Return ``True`` if the left endpoint is negative infinity. """ + return self.left is S.NegativeInfinity or self.left == Float("-inf") + + @property + def is_right_unbounded(self): + """Return ``True`` if the right endpoint is positive infinity. """ + return self.right is S.Infinity or self.right == Float("+inf") + + def _eval_Eq(self, other): + if not isinstance(other, Interval): + if isinstance(other, FiniteSet): + return false + elif isinstance(other, Set): + return None + return false + + +class Union(Set, LatticeOp): + """ + Represents a union of sets as a :class:`Set`. + + Examples + ======== + + >>> from sympy import Union, Interval + >>> Union(Interval(1, 2), Interval(3, 4)) + Union(Interval(1, 2), Interval(3, 4)) + + The Union constructor will always try to merge overlapping intervals, + if possible. For example: + + >>> Union(Interval(1, 2), Interval(2, 3)) + Interval(1, 3) + + See Also + ======== + + Intersection + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Union_%28set_theory%29 + """ + is_Union = True + + @property + def identity(self): + return S.EmptySet + + @property + def zero(self): + return S.UniversalSet + + def __new__(cls, *args, **kwargs): + evaluate = kwargs.get('evaluate', global_parameters.evaluate) + + # flatten inputs to merge intersections and iterables + args = _sympify(args) + + # Reduce sets using known rules + if evaluate: + args = list(cls._new_args_filter(args)) + return simplify_union(args) + + args = list(ordered(args, Set._infimum_key)) + + obj = Basic.__new__(cls, *args) + obj._argset = frozenset(args) + return obj + + @property + def args(self): + return self._args + + def _complement(self, universe): + # DeMorgan's Law + return Intersection(s.complement(universe) for s in self.args) + + @property + def _inf(self): + # We use Min so that sup is meaningful in combination with symbolic + # interval end points. + return Min(*[set.inf for set in self.args]) + + @property + def _sup(self): + # We use Max so that sup is meaningful in combination with symbolic + # end points. + return Max(*[set.sup for set in self.args]) + + @property + def is_empty(self): + return fuzzy_and(set.is_empty for set in self.args) + + @property + def is_finite_set(self): + return fuzzy_and(set.is_finite_set for set in self.args) + + @property + def _measure(self): + # Measure of a union is the sum of the measures of the sets minus + # the sum of their pairwise intersections plus the sum of their + # triple-wise intersections minus ... etc... + + # Sets is a collection of intersections and a set of elementary + # sets which made up those intersections (called "sos" for set of sets) + # An example element might of this list might be: + # ( {A,B,C}, A.intersect(B).intersect(C) ) + + # Start with just elementary sets ( ({A}, A), ({B}, B), ... ) + # Then get and subtract ( ({A,B}, (A int B), ... ) while non-zero + sets = [(FiniteSet(s), s) for s in self.args] + measure = 0 + parity = 1 + while sets: + # Add up the measure of these sets and add or subtract it to total + measure += parity * sum(inter.measure for sos, inter in sets) + + # For each intersection in sets, compute the intersection with every + # other set not already part of the intersection. + sets = ((sos + FiniteSet(newset), newset.intersect(intersection)) + for sos, intersection in sets for newset in self.args + if newset not in sos) + + # Clear out sets with no measure + sets = [(sos, inter) for sos, inter in sets if inter.measure != 0] + + # Clear out duplicates + sos_list = [] + sets_list = [] + for _set in sets: + if _set[0] in sos_list: + continue + else: + sos_list.append(_set[0]) + sets_list.append(_set) + sets = sets_list + + # Flip Parity - next time subtract/add if we added/subtracted here + parity *= -1 + return measure + + def _kind(self): + kinds = tuple(arg.kind for arg in self.args if arg is not S.EmptySet) + if not kinds: + return SetKind() + elif all(i == kinds[0] for i in kinds): + return kinds[0] + else: + return SetKind(UndefinedKind) + + @property + def _boundary(self): + def boundary_of_set(i): + """ The boundary of set i minus interior of all other sets """ + b = self.args[i].boundary + for j, a in enumerate(self.args): + if j != i: + b = b - a.interior + return b + return Union(*map(boundary_of_set, range(len(self.args)))) + + def _contains(self, other): + return Or(*[s.contains(other) for s in self.args]) + + def is_subset(self, other): + return fuzzy_and(s.is_subset(other) for s in self.args) + + def as_relational(self, symbol): + """Rewrite a Union in terms of equalities and logic operators. """ + if (len(self.args) == 2 and + all(isinstance(i, Interval) for i in self.args)): + # optimization to give 3 args as (x > 1) & (x < 5) & Ne(x, 3) + # instead of as 4, ((1 <= x) & (x < 3)) | ((x <= 5) & (3 < x)) + # XXX: This should be ideally be improved to handle any number of + # intervals and also not to assume that the intervals are in any + # particular sorted order. + a, b = self.args + if a.sup == b.inf and a.right_open and b.left_open: + mincond = symbol > a.inf if a.left_open else symbol >= a.inf + maxcond = symbol < b.sup if b.right_open else symbol <= b.sup + necond = Ne(symbol, a.sup) + return And(necond, mincond, maxcond) + return Or(*[i.as_relational(symbol) for i in self.args]) + + @property + def is_iterable(self): + return all(arg.is_iterable for arg in self.args) + + def __iter__(self): + return roundrobin(*(iter(arg) for arg in self.args)) + + +class Intersection(Set, LatticeOp): + """ + Represents an intersection of sets as a :class:`Set`. + + Examples + ======== + + >>> from sympy import Intersection, Interval + >>> Intersection(Interval(1, 3), Interval(2, 4)) + Interval(2, 3) + + We often use the .intersect method + + >>> Interval(1,3).intersect(Interval(2,4)) + Interval(2, 3) + + See Also + ======== + + Union + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Intersection_%28set_theory%29 + """ + is_Intersection = True + + @property + def identity(self): + return S.UniversalSet + + @property + def zero(self): + return S.EmptySet + + def __new__(cls, *args , evaluate=None): + if evaluate is None: + evaluate = global_parameters.evaluate + + # flatten inputs to merge intersections and iterables + args = list(ordered(set(_sympify(args)))) + + # Reduce sets using known rules + if evaluate: + args = list(cls._new_args_filter(args)) + return simplify_intersection(args) + + args = list(ordered(args, Set._infimum_key)) + + obj = Basic.__new__(cls, *args) + obj._argset = frozenset(args) + return obj + + @property + def args(self): + return self._args + + @property + def is_iterable(self): + return any(arg.is_iterable for arg in self.args) + + @property + def is_finite_set(self): + if fuzzy_or(arg.is_finite_set for arg in self.args): + return True + + def _kind(self): + kinds = tuple(arg.kind for arg in self.args if arg is not S.UniversalSet) + if not kinds: + return SetKind(UndefinedKind) + elif all(i == kinds[0] for i in kinds): + return kinds[0] + else: + return SetKind() + + @property + def _inf(self): + raise NotImplementedError() + + @property + def _sup(self): + raise NotImplementedError() + + def _contains(self, other): + return And(*[set.contains(other) for set in self.args]) + + def __iter__(self): + sets_sift = sift(self.args, lambda x: x.is_iterable) + + completed = False + candidates = sets_sift[True] + sets_sift[None] + + finite_candidates, others = [], [] + for candidate in candidates: + length = None + try: + length = len(candidate) + except TypeError: + others.append(candidate) + + if length is not None: + finite_candidates.append(candidate) + finite_candidates.sort(key=len) + + for s in finite_candidates + others: + other_sets = set(self.args) - {s} + other = Intersection(*other_sets, evaluate=False) + completed = True + for x in s: + try: + if x in other: + yield x + except TypeError: + completed = False + if completed: + return + + if not completed: + if not candidates: + raise TypeError("None of the constituent sets are iterable") + raise TypeError( + "The computation had not completed because of the " + "undecidable set membership is found in every candidates.") + + @staticmethod + def _handle_finite_sets(args): + '''Simplify intersection of one or more FiniteSets and other sets''' + + # First separate the FiniteSets from the others + fs_args, others = sift(args, lambda x: x.is_FiniteSet, binary=True) + + # Let the caller handle intersection of non-FiniteSets + if not fs_args: + return + + # Convert to Python sets and build the set of all elements + fs_sets = [set(fs) for fs in fs_args] + all_elements = reduce(lambda a, b: a | b, fs_sets, set()) + + # Extract elements that are definitely in or definitely not in the + # intersection. Here we check contains for all of args. + definite = set() + for e in all_elements: + inall = fuzzy_and(s.contains(e) for s in args) + if inall is True: + definite.add(e) + if inall is not None: + for s in fs_sets: + s.discard(e) + + # At this point all elements in all of fs_sets are possibly in the + # intersection. In some cases this is because they are definitely in + # the intersection of the finite sets but it's not clear if they are + # members of others. We might have {m, n}, {m}, and Reals where we + # don't know if m or n is real. We want to remove n here but it is + # possibly in because it might be equal to m. So what we do now is + # extract the elements that are definitely in the remaining finite + # sets iteratively until we end up with {n}, {}. At that point if we + # get any empty set all remaining elements are discarded. + + fs_elements = reduce(lambda a, b: a | b, fs_sets, set()) + + # Need fuzzy containment testing + fs_symsets = [FiniteSet(*s) for s in fs_sets] + + while fs_elements: + for e in fs_elements: + infs = fuzzy_and(s.contains(e) for s in fs_symsets) + if infs is True: + definite.add(e) + if infs is not None: + for n, s in enumerate(fs_sets): + # Update Python set and FiniteSet + if e in s: + s.remove(e) + fs_symsets[n] = FiniteSet(*s) + fs_elements.remove(e) + break + # If we completed the for loop without removing anything we are + # done so quit the outer while loop + else: + break + + # If any of the sets of remainder elements is empty then we discard + # all of them for the intersection. + if not all(fs_sets): + fs_sets = [set()] + + # Here we fold back the definitely included elements into each fs. + # Since they are definitely included they must have been members of + # each FiniteSet to begin with. We could instead fold these in with a + # Union at the end to get e.g. {3}|({x}&{y}) rather than {3,x}&{3,y}. + if definite: + fs_sets = [fs | definite for fs in fs_sets] + + if fs_sets == [set()]: + return S.EmptySet + + sets = [FiniteSet(*s) for s in fs_sets] + + # Any set in others is redundant if it contains all the elements that + # are in the finite sets so we don't need it in the Intersection + all_elements = reduce(lambda a, b: a | b, fs_sets, set()) + is_redundant = lambda o: all(fuzzy_bool(o.contains(e)) for e in all_elements) + others = [o for o in others if not is_redundant(o)] + + if others: + rest = Intersection(*others) + # XXX: Maybe this shortcut should be at the beginning. For large + # FiniteSets it could much more efficient to process the other + # sets first... + if rest is S.EmptySet: + return S.EmptySet + # Flatten the Intersection + if rest.is_Intersection: + sets.extend(rest.args) + else: + sets.append(rest) + + if len(sets) == 1: + return sets[0] + else: + return Intersection(*sets, evaluate=False) + + def as_relational(self, symbol): + """Rewrite an Intersection in terms of equalities and logic operators""" + return And(*[set.as_relational(symbol) for set in self.args]) + + +class Complement(Set): + r"""Represents the set difference or relative complement of a set with + another set. + + $$A - B = \{x \in A \mid x \notin B\}$$ + + + Examples + ======== + + >>> from sympy import Complement, FiniteSet + >>> Complement(FiniteSet(0, 1, 2), FiniteSet(1)) + {0, 2} + + See Also + ========= + + Intersection, Union + + References + ========== + + .. [1] https://mathworld.wolfram.com/ComplementSet.html + """ + + is_Complement = True + + def __new__(cls, a, b, evaluate=True): + a, b = map(_sympify, (a, b)) + if evaluate: + return Complement.reduce(a, b) + + return Basic.__new__(cls, a, b) + + @staticmethod + def reduce(A, B): + """ + Simplify a :class:`Complement`. + + """ + if B == S.UniversalSet or A.is_subset(B): + return S.EmptySet + + if isinstance(B, Union): + return Intersection(*(s.complement(A) for s in B.args)) + + result = B._complement(A) + if result is not None: + return result + else: + return Complement(A, B, evaluate=False) + + def _contains(self, other): + A = self.args[0] + B = self.args[1] + return And(A.contains(other), Not(B.contains(other))) + + def as_relational(self, symbol): + """Rewrite a complement in terms of equalities and logic + operators""" + A, B = self.args + + A_rel = A.as_relational(symbol) + B_rel = Not(B.as_relational(symbol)) + + return And(A_rel, B_rel) + + def _kind(self): + return self.args[0].kind + + @property + def is_iterable(self): + if self.args[0].is_iterable: + return True + + @property + def is_finite_set(self): + A, B = self.args + a_finite = A.is_finite_set + if a_finite is True: + return True + elif a_finite is False and B.is_finite_set: + return False + + def __iter__(self): + A, B = self.args + for a in A: + if a not in B: + yield a + else: + continue + + +class EmptySet(Set, metaclass=Singleton): + """ + Represents the empty set. The empty set is available as a singleton + as ``S.EmptySet``. + + Examples + ======== + + >>> from sympy import S, Interval + >>> S.EmptySet + EmptySet + + >>> Interval(1, 2).intersect(S.EmptySet) + EmptySet + + See Also + ======== + + UniversalSet + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Empty_set + """ + is_empty = True + is_finite_set = True + is_FiniteSet = True + + @property # type: ignore + @deprecated( + """ + The is_EmptySet attribute of Set objects is deprecated. + Use 's is S.EmptySet" or 's.is_empty' instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-is-emptyset", + ) + def is_EmptySet(self): + return True + + @property + def _measure(self): + return 0 + + def _contains(self, other): + return false + + def as_relational(self, symbol): + return false + + def __len__(self): + return 0 + + def __iter__(self): + return iter([]) + + def _eval_powerset(self): + return FiniteSet(self) + + @property + def _boundary(self): + return self + + def _complement(self, other): + return other + + def _kind(self): + return SetKind() + + def _symmetric_difference(self, other): + return other + + +class UniversalSet(Set, metaclass=Singleton): + """ + Represents the set of all things. + The universal set is available as a singleton as ``S.UniversalSet``. + + Examples + ======== + + >>> from sympy import S, Interval + >>> S.UniversalSet + UniversalSet + + >>> Interval(1, 2).intersect(S.UniversalSet) + Interval(1, 2) + + See Also + ======== + + EmptySet + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Universal_set + """ + + is_UniversalSet = True + is_empty = False + is_finite_set = False + + def _complement(self, other): + return S.EmptySet + + def _symmetric_difference(self, other): + return other + + @property + def _measure(self): + return S.Infinity + + def _kind(self): + return SetKind(UndefinedKind) + + def _contains(self, other): + return true + + def as_relational(self, symbol): + return true + + @property + def _boundary(self): + return S.EmptySet + + +class FiniteSet(Set): + """ + Represents a finite set of Sympy expressions. + + Examples + ======== + + >>> from sympy import FiniteSet, Symbol, Interval, Naturals0 + >>> FiniteSet(1, 2, 3, 4) + {1, 2, 3, 4} + >>> 3 in FiniteSet(1, 2, 3, 4) + True + >>> FiniteSet(1, (1, 2), Symbol('x')) + {1, x, (1, 2)} + >>> FiniteSet(Interval(1, 2), Naturals0, {1, 2}) + FiniteSet({1, 2}, Interval(1, 2), Naturals0) + >>> members = [1, 2, 3, 4] + >>> f = FiniteSet(*members) + >>> f + {1, 2, 3, 4} + >>> f - FiniteSet(2) + {1, 3, 4} + >>> f + FiniteSet(2, 5) + {1, 2, 3, 4, 5} + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Finite_set + """ + is_FiniteSet = True + is_iterable = True + is_empty = False + is_finite_set = True + + def __new__(cls, *args, **kwargs): + evaluate = kwargs.get('evaluate', global_parameters.evaluate) + if evaluate: + args = list(map(sympify, args)) + + if len(args) == 0: + return S.EmptySet + else: + args = list(map(sympify, args)) + + # keep the form of the first canonical arg + dargs = {} + for i in reversed(list(ordered(args))): + if i.is_Symbol: + dargs[i] = i + else: + try: + dargs[i.as_dummy()] = i + except TypeError: + # e.g. i = class without args like `Interval` + dargs[i] = i + _args_set = set(dargs.values()) + args = list(ordered(_args_set, Set._infimum_key)) + obj = Basic.__new__(cls, *args) + obj._args_set = _args_set + return obj + + + def __iter__(self): + return iter(self.args) + + def _complement(self, other): + if isinstance(other, Interval): + # Splitting in sub-intervals is only done for S.Reals; + # other cases that need splitting will first pass through + # Set._complement(). + nums, syms = [], [] + for m in self.args: + if m.is_number and m.is_real: + nums.append(m) + elif m.is_real == False: + pass # drop non-reals + else: + syms.append(m) # various symbolic expressions + if other == S.Reals and nums != []: + nums.sort() + intervals = [] # Build up a list of intervals between the elements + intervals += [Interval(S.NegativeInfinity, nums[0], True, True)] + for a, b in zip(nums[:-1], nums[1:]): + intervals.append(Interval(a, b, True, True)) # both open + intervals.append(Interval(nums[-1], S.Infinity, True, True)) + if syms != []: + return Complement(Union(*intervals, evaluate=False), + FiniteSet(*syms), evaluate=False) + else: + return Union(*intervals, evaluate=False) + elif nums == []: # no splitting necessary or possible: + if syms: + return Complement(other, FiniteSet(*syms), evaluate=False) + else: + return other + + elif isinstance(other, FiniteSet): + unk = [] + for i in self: + c = sympify(other.contains(i)) + if c is not S.true and c is not S.false: + unk.append(i) + unk = FiniteSet(*unk) + if unk == self: + return + not_true = [] + for i in other: + c = sympify(self.contains(i)) + if c is not S.true: + not_true.append(i) + return Complement(FiniteSet(*not_true), unk) + + return Set._complement(self, other) + + def _contains(self, other): + """ + Tests whether an element, other, is in the set. + + Explanation + =========== + + The actual test is for mathematical equality (as opposed to + syntactical equality). In the worst case all elements of the + set must be checked. + + Examples + ======== + + >>> from sympy import FiniteSet + >>> 1 in FiniteSet(1, 2) + True + >>> 5 in FiniteSet(1, 2) + False + + """ + if other in self._args_set: + return S.true + else: + # evaluate=True is needed to override evaluate=False context; + # we need Eq to do the evaluation + return Or(*[Eq(e, other, evaluate=True) for e in self.args]) + + def _eval_is_subset(self, other): + return fuzzy_and(other._contains(e) for e in self.args) + + @property + def _boundary(self): + return self + + @property + def _inf(self): + return Min(*self) + + @property + def _sup(self): + return Max(*self) + + @property + def measure(self): + return 0 + + def _kind(self): + if not self.args: + return SetKind() + elif all(i.kind == self.args[0].kind for i in self.args): + return SetKind(self.args[0].kind) + else: + return SetKind(UndefinedKind) + + def __len__(self): + return len(self.args) + + def as_relational(self, symbol): + """Rewrite a FiniteSet in terms of equalities and logic operators. """ + return Or(*[Eq(symbol, elem) for elem in self]) + + def compare(self, other): + return (hash(self) - hash(other)) + + def _eval_evalf(self, prec): + dps = prec_to_dps(prec) + return FiniteSet(*[elem.evalf(n=dps) for elem in self]) + + def _eval_simplify(self, **kwargs): + from sympy.simplify import simplify + return FiniteSet(*[simplify(elem, **kwargs) for elem in self]) + + @property + def _sorted_args(self): + return self.args + + def _eval_powerset(self): + return self.func(*[self.func(*s) for s in subsets(self.args)]) + + def _eval_rewrite_as_PowerSet(self, *args, **kwargs): + """Rewriting method for a finite set to a power set.""" + from .powerset import PowerSet + + is2pow = lambda n: bool(n and not n & (n - 1)) + if not is2pow(len(self)): + return None + + fs_test = lambda arg: isinstance(arg, Set) and arg.is_FiniteSet + if not all(fs_test(arg) for arg in args): + return None + + biggest = max(args, key=len) + for arg in subsets(biggest.args): + arg_set = FiniteSet(*arg) + if arg_set not in args: + return None + return PowerSet(biggest) + + def __ge__(self, other): + if not isinstance(other, Set): + raise TypeError("Invalid comparison of set with %s" % func_name(other)) + return other.is_subset(self) + + def __gt__(self, other): + if not isinstance(other, Set): + raise TypeError("Invalid comparison of set with %s" % func_name(other)) + return self.is_proper_superset(other) + + def __le__(self, other): + if not isinstance(other, Set): + raise TypeError("Invalid comparison of set with %s" % func_name(other)) + return self.is_subset(other) + + def __lt__(self, other): + if not isinstance(other, Set): + raise TypeError("Invalid comparison of set with %s" % func_name(other)) + return self.is_proper_subset(other) + + def __eq__(self, other): + if isinstance(other, (set, frozenset)): + return self._args_set == other + return super().__eq__(other) + + __hash__ : Callable[[Basic], Any] = Basic.__hash__ + +_sympy_converter[set] = lambda x: FiniteSet(*x) +_sympy_converter[frozenset] = lambda x: FiniteSet(*x) + + +class SymmetricDifference(Set): + """Represents the set of elements which are in either of the + sets and not in their intersection. + + Examples + ======== + + >>> from sympy import SymmetricDifference, FiniteSet + >>> SymmetricDifference(FiniteSet(1, 2, 3), FiniteSet(3, 4, 5)) + {1, 2, 4, 5} + + See Also + ======== + + Complement, Union + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Symmetric_difference + """ + + is_SymmetricDifference = True + + def __new__(cls, a, b, evaluate=True): + if evaluate: + return SymmetricDifference.reduce(a, b) + + return Basic.__new__(cls, a, b) + + @staticmethod + def reduce(A, B): + result = B._symmetric_difference(A) + if result is not None: + return result + else: + return SymmetricDifference(A, B, evaluate=False) + + def as_relational(self, symbol): + """Rewrite a symmetric_difference in terms of equalities and + logic operators""" + A, B = self.args + + A_rel = A.as_relational(symbol) + B_rel = B.as_relational(symbol) + + return Xor(A_rel, B_rel) + + @property + def is_iterable(self): + if all(arg.is_iterable for arg in self.args): + return True + + def __iter__(self): + + args = self.args + union = roundrobin(*(iter(arg) for arg in args)) + + for item in union: + count = 0 + for s in args: + if item in s: + count += 1 + + if count % 2 == 1: + yield item + + + +class DisjointUnion(Set): + """ Represents the disjoint union (also known as the external disjoint union) + of a finite number of sets. + + Examples + ======== + + >>> from sympy import DisjointUnion, FiniteSet, Interval, Union, Symbol + >>> A = FiniteSet(1, 2, 3) + >>> B = Interval(0, 5) + >>> DisjointUnion(A, B) + DisjointUnion({1, 2, 3}, Interval(0, 5)) + >>> DisjointUnion(A, B).rewrite(Union) + Union(ProductSet({1, 2, 3}, {0}), ProductSet(Interval(0, 5), {1})) + >>> C = FiniteSet(Symbol('x'), Symbol('y'), Symbol('z')) + >>> DisjointUnion(C, C) + DisjointUnion({x, y, z}, {x, y, z}) + >>> DisjointUnion(C, C).rewrite(Union) + ProductSet({x, y, z}, {0, 1}) + + References + ========== + + https://en.wikipedia.org/wiki/Disjoint_union + """ + + def __new__(cls, *sets): + dj_collection = [] + for set_i in sets: + if isinstance(set_i, Set): + dj_collection.append(set_i) + else: + raise TypeError("Invalid input: '%s', input args \ + to DisjointUnion must be Sets" % set_i) + obj = Basic.__new__(cls, *dj_collection) + return obj + + @property + def sets(self): + return self.args + + @property + def is_empty(self): + return fuzzy_and(s.is_empty for s in self.sets) + + @property + def is_finite_set(self): + all_finite = fuzzy_and(s.is_finite_set for s in self.sets) + return fuzzy_or([self.is_empty, all_finite]) + + @property + def is_iterable(self): + if self.is_empty: + return False + iter_flag = True + for set_i in self.sets: + if not set_i.is_empty: + iter_flag = iter_flag and set_i.is_iterable + return iter_flag + + def _eval_rewrite_as_Union(self, *sets, **kwargs): + """ + Rewrites the disjoint union as the union of (``set`` x {``i``}) + where ``set`` is the element in ``sets`` at index = ``i`` + """ + + dj_union = S.EmptySet + index = 0 + for set_i in sets: + if isinstance(set_i, Set): + cross = ProductSet(set_i, FiniteSet(index)) + dj_union = Union(dj_union, cross) + index = index + 1 + return dj_union + + def _contains(self, element): + """ + ``in`` operator for DisjointUnion + + Examples + ======== + + >>> from sympy import Interval, DisjointUnion + >>> D = DisjointUnion(Interval(0, 1), Interval(0, 2)) + >>> (0.5, 0) in D + True + >>> (0.5, 1) in D + True + >>> (1.5, 0) in D + False + >>> (1.5, 1) in D + True + + Passes operation on to constituent sets + """ + if not isinstance(element, Tuple) or len(element) != 2: + return S.false + + if not element[1].is_Integer: + return S.false + + if element[1] >= len(self.sets) or element[1] < 0: + return S.false + + return self.sets[element[1]]._contains(element[0]) + + def _kind(self): + if not self.args: + return SetKind() + elif all(i.kind == self.args[0].kind for i in self.args): + return self.args[0].kind + else: + return SetKind(UndefinedKind) + + def __iter__(self): + if self.is_iterable: + + iters = [] + for i, s in enumerate(self.sets): + iters.append(iproduct(s, {Integer(i)})) + + return iter(roundrobin(*iters)) + else: + raise ValueError("'%s' is not iterable." % self) + + def __len__(self): + """ + Returns the length of the disjoint union, i.e., the number of elements in the set. + + Examples + ======== + + >>> from sympy import FiniteSet, DisjointUnion, EmptySet + >>> D1 = DisjointUnion(FiniteSet(1, 2, 3, 4), EmptySet, FiniteSet(3, 4, 5)) + >>> len(D1) + 7 + >>> D2 = DisjointUnion(FiniteSet(3, 5, 7), EmptySet, FiniteSet(3, 5, 7)) + >>> len(D2) + 6 + >>> D3 = DisjointUnion(EmptySet, EmptySet) + >>> len(D3) + 0 + + Adds up the lengths of the constituent sets. + """ + + if self.is_finite_set: + size = 0 + for set in self.sets: + size += len(set) + return size + else: + raise ValueError("'%s' is not a finite set." % self) + + +def imageset(*args): + r""" + Return an image of the set under transformation ``f``. + + Explanation + =========== + + If this function cannot compute the image, it returns an + unevaluated ImageSet object. + + .. math:: + \{ f(x) \mid x \in \mathrm{self} \} + + Examples + ======== + + >>> from sympy import S, Interval, imageset, sin, Lambda + >>> from sympy.abc import x + + >>> imageset(x, 2*x, Interval(0, 2)) + Interval(0, 4) + + >>> imageset(lambda x: 2*x, Interval(0, 2)) + Interval(0, 4) + + >>> imageset(Lambda(x, sin(x)), Interval(-2, 1)) + ImageSet(Lambda(x, sin(x)), Interval(-2, 1)) + + >>> imageset(sin, Interval(-2, 1)) + ImageSet(Lambda(x, sin(x)), Interval(-2, 1)) + >>> imageset(lambda y: x + y, Interval(-2, 1)) + ImageSet(Lambda(y, x + y), Interval(-2, 1)) + + Expressions applied to the set of Integers are simplified + to show as few negatives as possible and linear expressions + are converted to a canonical form. If this is not desirable + then the unevaluated ImageSet should be used. + + >>> imageset(x, -2*x + 5, S.Integers) + ImageSet(Lambda(x, 2*x + 1), Integers) + + See Also + ======== + + sympy.sets.fancysets.ImageSet + + """ + from .fancysets import ImageSet + from .setexpr import set_function + + if len(args) < 2: + raise ValueError('imageset expects at least 2 args, got: %s' % len(args)) + + if isinstance(args[0], (Symbol, tuple)) and len(args) > 2: + f = Lambda(args[0], args[1]) + set_list = args[2:] + else: + f = args[0] + set_list = args[1:] + + if isinstance(f, Lambda): + pass + elif callable(f): + nargs = getattr(f, 'nargs', {}) + if nargs: + if len(nargs) != 1: + raise NotImplementedError(filldedent(''' + This function can take more than 1 arg + but the potentially complicated set input + has not been analyzed at this point to + know its dimensions. TODO + ''')) + N = nargs.args[0] + if N == 1: + s = 'x' + else: + s = [Symbol('x%i' % i) for i in range(1, N + 1)] + else: + s = inspect.signature(f).parameters + + dexpr = _sympify(f(*[Dummy() for i in s])) + var = tuple(uniquely_named_symbol( + Symbol(i), dexpr) for i in s) + f = Lambda(var, f(*var)) + else: + raise TypeError(filldedent(''' + expecting lambda, Lambda, or FunctionClass, + not \'%s\'.''' % func_name(f))) + + if any(not isinstance(s, Set) for s in set_list): + name = [func_name(s) for s in set_list] + raise ValueError( + 'arguments after mapping should be sets, not %s' % name) + + if len(set_list) == 1: + set = set_list[0] + try: + # TypeError if arg count != set dimensions + r = set_function(f, set) + if r is None: + raise TypeError + if not r: + return r + except TypeError: + r = ImageSet(f, set) + if isinstance(r, ImageSet): + f, set = r.args + + if f.variables[0] == f.expr: + return set + + if isinstance(set, ImageSet): + # XXX: Maybe this should just be: + # f2 = set.lambda + # fun = Lambda(f2.signature, f(*f2.expr)) + # return imageset(fun, *set.base_sets) + if len(set.lamda.variables) == 1 and len(f.variables) == 1: + x = set.lamda.variables[0] + y = f.variables[0] + return imageset( + Lambda(x, f.expr.subs(y, set.lamda.expr)), *set.base_sets) + + if r is not None: + return r + + return ImageSet(f, *set_list) + + +def is_function_invertible_in_set(func, setv): + """ + Checks whether function ``func`` is invertible when the domain is + restricted to set ``setv``. + """ + # Functions known to always be invertible: + if func in (exp, log): + return True + u = Dummy("u") + fdiff = func(u).diff(u) + # monotonous functions: + # TODO: check subsets (`func` in `setv`) + if (fdiff > 0) == True or (fdiff < 0) == True: + return True + # TODO: support more + return None + + +def simplify_union(args): + """ + Simplify a :class:`Union` using known rules. + + Explanation + =========== + + We first start with global rules like 'Merge all FiniteSets' + + Then we iterate through all pairs and ask the constituent sets if they + can simplify themselves with any other constituent. This process depends + on ``union_sets(a, b)`` functions. + """ + from sympy.sets.handlers.union import union_sets + + # ===== Global Rules ===== + if not args: + return S.EmptySet + + for arg in args: + if not isinstance(arg, Set): + raise TypeError("Input args to Union must be Sets") + + # Merge all finite sets + finite_sets = [x for x in args if x.is_FiniteSet] + if len(finite_sets) > 1: + a = (x for set in finite_sets for x in set) + finite_set = FiniteSet(*a) + args = [finite_set] + [x for x in args if not x.is_FiniteSet] + + # ===== Pair-wise Rules ===== + # Here we depend on rules built into the constituent sets + args = set(args) + new_args = True + while new_args: + for s in args: + new_args = False + for t in args - {s}: + new_set = union_sets(s, t) + # This returns None if s does not know how to intersect + # with t. Returns the newly intersected set otherwise + if new_set is not None: + if not isinstance(new_set, set): + new_set = {new_set} + new_args = (args - {s, t}).union(new_set) + break + if new_args: + args = new_args + break + + if len(args) == 1: + return args.pop() + else: + return Union(*args, evaluate=False) + + +def simplify_intersection(args): + """ + Simplify an intersection using known rules. + + Explanation + =========== + + We first start with global rules like + 'if any empty sets return empty set' and 'distribute any unions' + + Then we iterate through all pairs and ask the constituent sets if they + can simplify themselves with any other constituent + """ + + # ===== Global Rules ===== + if not args: + return S.UniversalSet + + for arg in args: + if not isinstance(arg, Set): + raise TypeError("Input args to Union must be Sets") + + # If any EmptySets return EmptySet + if S.EmptySet in args: + return S.EmptySet + + # Handle Finite sets + rv = Intersection._handle_finite_sets(args) + + if rv is not None: + return rv + + # If any of the sets are unions, return a Union of Intersections + for s in args: + if s.is_Union: + other_sets = set(args) - {s} + if len(other_sets) > 0: + other = Intersection(*other_sets) + return Union(*(Intersection(arg, other) for arg in s.args)) + else: + return Union(*s.args) + + for s in args: + if s.is_Complement: + args.remove(s) + other_sets = args + [s.args[0]] + return Complement(Intersection(*other_sets), s.args[1]) + + from sympy.sets.handlers.intersection import intersection_sets + + # At this stage we are guaranteed not to have any + # EmptySets, FiniteSets, or Unions in the intersection + + # ===== Pair-wise Rules ===== + # Here we depend on rules built into the constituent sets + args = set(args) + new_args = True + while new_args: + for s in args: + new_args = False + for t in args - {s}: + new_set = intersection_sets(s, t) + # This returns None if s does not know how to intersect + # with t. Returns the newly intersected set otherwise + + if new_set is not None: + new_args = (args - {s, t}).union({new_set}) + break + if new_args: + args = new_args + break + + if len(args) == 1: + return args.pop() + else: + return Intersection(*args, evaluate=False) + + +def _handle_finite_sets(op, x, y, commutative): + # Handle finite sets: + fs_args, other = sift([x, y], lambda x: isinstance(x, FiniteSet), binary=True) + if len(fs_args) == 2: + return FiniteSet(*[op(i, j) for i in fs_args[0] for j in fs_args[1]]) + elif len(fs_args) == 1: + sets = [_apply_operation(op, other[0], i, commutative) for i in fs_args[0]] + return Union(*sets) + else: + return None + + +def _apply_operation(op, x, y, commutative): + from .fancysets import ImageSet + d = Dummy('d') + + out = _handle_finite_sets(op, x, y, commutative) + if out is None: + out = op(x, y) + + if out is None and commutative: + out = op(y, x) + if out is None: + _x, _y = symbols("x y") + if isinstance(x, Set) and not isinstance(y, Set): + out = ImageSet(Lambda(d, op(d, y)), x).doit() + elif not isinstance(x, Set) and isinstance(y, Set): + out = ImageSet(Lambda(d, op(x, d)), y).doit() + else: + out = ImageSet(Lambda((_x, _y), op(_x, _y)), x, y) + return out + + +def set_add(x, y): + from sympy.sets.handlers.add import _set_add + return _apply_operation(_set_add, x, y, commutative=True) + + +def set_sub(x, y): + from sympy.sets.handlers.add import _set_sub + return _apply_operation(_set_sub, x, y, commutative=False) + + +def set_mul(x, y): + from sympy.sets.handlers.mul import _set_mul + return _apply_operation(_set_mul, x, y, commutative=True) + + +def set_div(x, y): + from sympy.sets.handlers.mul import _set_div + return _apply_operation(_set_div, x, y, commutative=False) + + +def set_pow(x, y): + from sympy.sets.handlers.power import _set_pow + return _apply_operation(_set_pow, x, y, commutative=False) + + +def set_function(f, x): + from sympy.sets.handlers.functions import _set_function + return _set_function(f, x) + + +class SetKind(Kind): + """ + SetKind is kind for all Sets + + Every instance of Set will have kind ``SetKind`` parametrised by the kind + of the elements of the ``Set``. The kind of the elements might be + ``NumberKind``, or ``TupleKind`` or something else. When not all elements + have the same kind then the kind of the elements will be given as + ``UndefinedKind``. + + Parameters + ========== + + element_kind: Kind (optional) + The kind of the elements of the set. In a well defined set all elements + will have the same kind. Otherwise the kind should + :class:`sympy.core.kind.UndefinedKind`. The ``element_kind`` argument is optional but + should only be omitted in the case of ``EmptySet`` whose kind is simply + ``SetKind()`` + + Examples + ======== + + >>> from sympy import Interval + >>> Interval(1, 2).kind + SetKind(NumberKind) + >>> Interval(1,2).kind.element_kind + NumberKind + + See Also + ======== + + sympy.core.kind.NumberKind + sympy.matrices.kind.MatrixKind + sympy.core.containers.TupleKind + """ + def __new__(cls, element_kind=None): + obj = super().__new__(cls, element_kind) + obj.element_kind = element_kind + return obj + + def __repr__(self): + if not self.element_kind: + return "SetKind()" + else: + return "SetKind(%s)" % self.element_kind diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/__init__.py b/.venv/lib/python3.13/site-packages/sympy/solvers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02cfb35a765c748e70ecc36dd78d8d4432118c64 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/__init__.py @@ -0,0 +1,75 @@ +"""A module for solving all kinds of equations. + + Examples + ======== + + >>> from sympy.solvers import solve + >>> from sympy.abc import x + >>> solve(((x + 1)**5).expand(), x) + [-1] +""" +from sympy.core.assumptions import check_assumptions, failing_assumptions + +from .solvers import solve, solve_linear_system, solve_linear_system_LU, \ + solve_undetermined_coeffs, nsolve, solve_linear, checksol, \ + det_quick, inv_quick + +from sympy.solvers.diophantine.diophantine import diophantine + +from .recurr import rsolve, rsolve_poly, rsolve_ratio, rsolve_hyper + +from .ode import checkodesol, classify_ode, dsolve, \ + homogeneous_order + +from .polysys import solve_poly_system, solve_triangulated, factor_system + +from .pde import pde_separate, pde_separate_add, pde_separate_mul, \ + pdsolve, classify_pde, checkpdesol + +from .deutils import ode_order + +from .inequalities import reduce_inequalities, reduce_abs_inequality, \ + reduce_abs_inequalities, solve_poly_inequality, solve_rational_inequalities, solve_univariate_inequality + +from .decompogen import decompogen + +from .solveset import solveset, linsolve, linear_eq_to_matrix, nonlinsolve, substitution + +from .simplex import lpmin, lpmax, linprog + +# This is here instead of sympy/sets/__init__.py to avoid circular import issues +from ..core.singleton import S +Complexes = S.Complexes + +__all__ = [ + 'solve', 'solve_linear_system', 'solve_linear_system_LU', + 'solve_undetermined_coeffs', 'nsolve', 'solve_linear', 'checksol', + 'det_quick', 'inv_quick', 'check_assumptions', 'failing_assumptions', + + 'diophantine', + + 'rsolve', 'rsolve_poly', 'rsolve_ratio', 'rsolve_hyper', + + 'checkodesol', 'classify_ode', 'dsolve', 'homogeneous_order', + + 'solve_poly_system', 'solve_triangulated', 'factor_system', + + 'pde_separate', 'pde_separate_add', 'pde_separate_mul', 'pdsolve', + 'classify_pde', 'checkpdesol', + + 'ode_order', + + 'reduce_inequalities', 'reduce_abs_inequality', 'reduce_abs_inequalities', + 'solve_poly_inequality', 'solve_rational_inequalities', + 'solve_univariate_inequality', + + 'decompogen', + + 'solveset', 'linsolve', 'linear_eq_to_matrix', 'nonlinsolve', + 'substitution', + + # This is here instead of sympy/sets/__init__.py to avoid circular import issues + 'Complexes', + + 'lpmin', 'lpmax', 'linprog' +] diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/bivariate.py b/.venv/lib/python3.13/site-packages/sympy/solvers/bivariate.py new file mode 100644 index 0000000000000000000000000000000000000000..72f8e266a16634fa65366e1058543dfe2171ba1c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/bivariate.py @@ -0,0 +1,509 @@ +from sympy.core.add import Add +from sympy.core.exprtools import factor_terms +from sympy.core.function import expand_log, _mexpand +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.sorting import ordered +from sympy.core.symbol import Dummy +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.miscellaneous import root +from sympy.polys.polyroots import roots +from sympy.polys.polytools import Poly, factor +from sympy.simplify.simplify import separatevars +from sympy.simplify.radsimp import collect +from sympy.simplify.simplify import powsimp +from sympy.solvers.solvers import solve, _invert +from sympy.utilities.iterables import uniq + + +def _filtered_gens(poly, symbol): + """process the generators of ``poly``, returning the set of generators that + have ``symbol``. If there are two generators that are inverses of each other, + prefer the one that has no denominator. + + Examples + ======== + + >>> from sympy.solvers.bivariate import _filtered_gens + >>> from sympy import Poly, exp + >>> from sympy.abc import x + >>> _filtered_gens(Poly(x + 1/x + exp(x)), x) + {x, exp(x)} + + """ + # TODO it would be good to pick the smallest divisible power + # instead of the base for something like x**4 + x**2 --> + # return x**2 not x + gens = {g for g in poly.gens if symbol in g.free_symbols} + for g in list(gens): + ag = 1/g + if g in gens and ag in gens: + if ag.as_numer_denom()[1] is not S.One: + g = ag + gens.remove(g) + return gens + + +def _mostfunc(lhs, func, X=None): + """Returns the term in lhs which contains the most of the + func-type things e.g. log(log(x)) wins over log(x) if both terms appear. + + ``func`` can be a function (exp, log, etc...) or any other SymPy object, + like Pow. + + If ``X`` is not ``None``, then the function returns the term composed with the + most ``func`` having the specified variable. + + Examples + ======== + + >>> from sympy.solvers.bivariate import _mostfunc + >>> from sympy import exp + >>> from sympy.abc import x, y + >>> _mostfunc(exp(x) + exp(exp(x) + 2), exp) + exp(exp(x) + 2) + >>> _mostfunc(exp(x) + exp(exp(y) + 2), exp) + exp(exp(y) + 2) + >>> _mostfunc(exp(x) + exp(exp(y) + 2), exp, x) + exp(x) + >>> _mostfunc(x, exp, x) is None + True + >>> _mostfunc(exp(x) + exp(x*y), exp, x) + exp(x) + """ + fterms = [tmp for tmp in lhs.atoms(func) if (not X or + X.is_Symbol and X in tmp.free_symbols or + not X.is_Symbol and tmp.has(X))] + if len(fterms) == 1: + return fterms[0] + elif fterms: + return max(list(ordered(fterms)), key=lambda x: x.count(func)) + return None + + +def _linab(arg, symbol): + """Return ``a, b, X`` assuming ``arg`` can be written as ``a*X + b`` + where ``X`` is a symbol-dependent factor and ``a`` and ``b`` are + independent of ``symbol``. + + Examples + ======== + + >>> from sympy.solvers.bivariate import _linab + >>> from sympy.abc import x, y + >>> from sympy import exp, S + >>> _linab(S(2), x) + (2, 0, 1) + >>> _linab(2*x, x) + (2, 0, x) + >>> _linab(y + y*x + 2*x, x) + (y + 2, y, x) + >>> _linab(3 + 2*exp(x), x) + (2, 3, exp(x)) + """ + arg = factor_terms(arg.expand()) + ind, dep = arg.as_independent(symbol) + if arg.is_Mul and dep.is_Add: + a, b, x = _linab(dep, symbol) + return ind*a, ind*b, x + if not arg.is_Add: + b = 0 + a, x = ind, dep + else: + b = ind + a, x = separatevars(dep).as_independent(symbol, as_Add=False) + if x.could_extract_minus_sign(): + a = -a + x = -x + return a, b, x + + +def _lambert(eq, x): + """ + Given an expression assumed to be in the form + ``F(X, a..f) = a*log(b*X + c) + d*X + f = 0`` + where X = g(x) and x = g^-1(X), return the Lambert solution, + ``x = g^-1(-c/b + (a/d)*W(d/(a*b)*exp(c*d/a/b)*exp(-f/a)))``. + """ + eq = _mexpand(expand_log(eq)) + mainlog = _mostfunc(eq, log, x) + if not mainlog: + return [] # violated assumptions + other = eq.subs(mainlog, 0) + if isinstance(-other, log): + eq = (eq - other).subs(mainlog, mainlog.args[0]) + mainlog = mainlog.args[0] + if not isinstance(mainlog, log): + return [] # violated assumptions + other = -(-other).args[0] + eq += other + if x not in other.free_symbols: + return [] # violated assumptions + d, f, X2 = _linab(other, x) + logterm = collect(eq - other, mainlog) + a = logterm.as_coefficient(mainlog) + if a is None or x in a.free_symbols: + return [] # violated assumptions + logarg = mainlog.args[0] + b, c, X1 = _linab(logarg, x) + if X1 != X2: + return [] # violated assumptions + + # invert the generator X1 so we have x(u) + u = Dummy('rhs') + xusolns = solve(X1 - u, x) + + # There are infinitely many branches for LambertW + # but only branches for k = -1 and 0 might be real. The k = 0 + # branch is real and the k = -1 branch is real if the LambertW argument + # in in range [-1/e, 0]. Since `solve` does not return infinite + # solutions we will only include the -1 branch if it tests as real. + # Otherwise, inclusion of any LambertW in the solution indicates to + # the user that there are imaginary solutions corresponding to + # different k values. + lambert_real_branches = [-1, 0] + sol = [] + + # solution of the given Lambert equation is like + # sol = -c/b + (a/d)*LambertW(arg, k), + # where arg = d/(a*b)*exp((c*d-b*f)/a/b) and k in lambert_real_branches. + # Instead of considering the single arg, `d/(a*b)*exp((c*d-b*f)/a/b)`, + # the individual `p` roots obtained when writing `exp((c*d-b*f)/a/b)` + # as `exp(A/p) = exp(A)**(1/p)`, where `p` is an Integer, are used. + + # calculating args for LambertW + num, den = ((c*d-b*f)/a/b).as_numer_denom() + p, den = den.as_coeff_Mul() + e = exp(num/den) + t = Dummy('t') + args = [d/(a*b)*t for t in roots(t**p - e, t).keys()] + + # calculating solutions from args + for arg in args: + for k in lambert_real_branches: + w = LambertW(arg, k) + if k and not w.is_real: + continue + rhs = -c/b + (a/d)*w + + sol.extend(xu.subs(u, rhs) for xu in xusolns) + return sol + + +def _solve_lambert(f, symbol, gens): + """Return solution to ``f`` if it is a Lambert-type expression + else raise NotImplementedError. + + For ``f(X, a..f) = a*log(b*X + c) + d*X - f = 0`` the solution + for ``X`` is ``X = -c/b + (a/d)*W(d/(a*b)*exp(c*d/a/b)*exp(f/a))``. + There are a variety of forms for `f(X, a..f)` as enumerated below: + + 1a1) + if B**B = R for R not in [0, 1] (since those cases would already + be solved before getting here) then log of both sides gives + log(B) + log(log(B)) = log(log(R)) and + X = log(B), a = 1, b = 1, c = 0, d = 1, f = log(log(R)) + 1a2) + if B*(b*log(B) + c)**a = R then log of both sides gives + log(B) + a*log(b*log(B) + c) = log(R) and + X = log(B), d=1, f=log(R) + 1b) + if a*log(b*B + c) + d*B = R and + X = B, f = R + 2a) + if (b*B + c)*exp(d*B + g) = R then log of both sides gives + log(b*B + c) + d*B + g = log(R) and + X = B, a = 1, f = log(R) - g + 2b) + if g*exp(d*B + h) - b*B = c then the log form is + log(g) + d*B + h - log(b*B + c) = 0 and + X = B, a = -1, f = -h - log(g) + 3) + if d*p**(a*B + g) - b*B = c then the log form is + log(d) + (a*B + g)*log(p) - log(b*B + c) = 0 and + X = B, a = -1, d = a*log(p), f = -log(d) - g*log(p) + """ + + def _solve_even_degree_expr(expr, t, symbol): + """Return the unique solutions of equations derived from + ``expr`` by replacing ``t`` with ``+/- symbol``. + + Parameters + ========== + + expr : Expr + The expression which includes a dummy variable t to be + replaced with +symbol and -symbol. + + symbol : Symbol + The symbol for which a solution is being sought. + + Returns + ======= + + List of unique solution of the two equations generated by + replacing ``t`` with positive and negative ``symbol``. + + Notes + ===== + + If ``expr = 2*log(t) + x/2` then solutions for + ``2*log(x) + x/2 = 0`` and ``2*log(-x) + x/2 = 0`` are + returned by this function. Though this may seem + counter-intuitive, one must note that the ``expr`` being + solved here has been derived from a different expression. For + an expression like ``eq = x**2*g(x) = 1``, if we take the + log of both sides we obtain ``log(x**2) + log(g(x)) = 0``. If + x is positive then this simplifies to + ``2*log(x) + log(g(x)) = 0``; the Lambert-solving routines will + return solutions for this, but we must also consider the + solutions for ``2*log(-x) + log(g(x))`` since those must also + be a solution of ``eq`` which has the same value when the ``x`` + in ``x**2`` is negated. If `g(x)` does not have even powers of + symbol then we do not want to replace the ``x`` there with + ``-x``. So the role of the ``t`` in the expression received by + this function is to mark where ``+/-x`` should be inserted + before obtaining the Lambert solutions. + + """ + nlhs, plhs = [ + expr.xreplace({t: sgn*symbol}) for sgn in (-1, 1)] + sols = _solve_lambert(nlhs, symbol, gens) + if plhs != nlhs: + sols.extend(_solve_lambert(plhs, symbol, gens)) + # uniq is needed for a case like + # 2*log(t) - log(-z**2) + log(z + log(x) + log(z)) + # where substituting t with +/-x gives all the same solution; + # uniq, rather than list(set()), is used to maintain canonical + # order + return list(uniq(sols)) + + nrhs, lhs = f.as_independent(symbol, as_Add=True) + rhs = -nrhs + + lamcheck = [tmp for tmp in gens + if (tmp.func in [exp, log] or + (tmp.is_Pow and symbol in tmp.exp.free_symbols))] + if not lamcheck: + raise NotImplementedError() + + if lhs.is_Add or lhs.is_Mul: + # replacing all even_degrees of symbol with dummy variable t + # since these will need special handling; non-Add/Mul do not + # need this handling + t = Dummy('t', **symbol.assumptions0) + lhs = lhs.replace( + lambda i: # find symbol**even + i.is_Pow and i.base == symbol and i.exp.is_even, + lambda i: # replace t**even + t**i.exp) + + if lhs.is_Add and lhs.has(t): + t_indep = lhs.subs(t, 0) + t_term = lhs - t_indep + _rhs = rhs - t_indep + if not t_term.is_Add and _rhs and not ( + t_term.has(S.ComplexInfinity, S.NaN)): + eq = expand_log(log(t_term) - log(_rhs)) + return _solve_even_degree_expr(eq, t, symbol) + elif lhs.is_Mul and rhs: + # this needs to happen whether t is present or not + lhs = expand_log(log(lhs), force=True) + rhs = log(rhs) + if lhs.has(t) and lhs.is_Add: + # it expanded from Mul to Add + eq = lhs - rhs + return _solve_even_degree_expr(eq, t, symbol) + + # restore symbol in lhs + lhs = lhs.xreplace({t: symbol}) + + lhs = powsimp(factor(lhs, deep=True)) + + # make sure we have inverted as completely as possible + r = Dummy() + i, lhs = _invert(lhs - r, symbol) + rhs = i.xreplace({r: rhs}) + + # For the first forms: + # + # 1a1) B**B = R will arrive here as B*log(B) = log(R) + # lhs is Mul so take log of both sides: + # log(B) + log(log(B)) = log(log(R)) + # 1a2) B*(b*log(B) + c)**a = R will arrive unchanged so + # lhs is Mul, so take log of both sides: + # log(B) + a*log(b*log(B) + c) = log(R) + # 1b) d*log(a*B + b) + c*B = R will arrive unchanged so + # lhs is Add, so isolate c*B and expand log of both sides: + # log(c) + log(B) = log(R - d*log(a*B + b)) + + soln = [] + if not soln: + mainlog = _mostfunc(lhs, log, symbol) + if mainlog: + if lhs.is_Mul and rhs != 0: + soln = _lambert(log(lhs) - log(rhs), symbol) + elif lhs.is_Add: + other = lhs.subs(mainlog, 0) + if other and not other.is_Add and [ + tmp for tmp in other.atoms(Pow) + if symbol in tmp.free_symbols]: + if not rhs: + diff = log(other) - log(other - lhs) + else: + diff = log(lhs - other) - log(rhs - other) + soln = _lambert(expand_log(diff), symbol) + else: + #it's ready to go + soln = _lambert(lhs - rhs, symbol) + + # For the next forms, + # + # collect on main exp + # 2a) (b*B + c)*exp(d*B + g) = R + # lhs is mul, so take log of both sides: + # log(b*B + c) + d*B = log(R) - g + # 2b) g*exp(d*B + h) - b*B = R + # lhs is add, so add b*B to both sides, + # take the log of both sides and rearrange to give + # log(R + b*B) - d*B = log(g) + h + + if not soln: + mainexp = _mostfunc(lhs, exp, symbol) + if mainexp: + lhs = collect(lhs, mainexp) + if lhs.is_Mul and rhs != 0: + soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol) + elif lhs.is_Add: + # move all but mainexp-containing term to rhs + other = lhs.subs(mainexp, 0) + mainterm = lhs - other + rhs = rhs - other + if (mainterm.could_extract_minus_sign() and + rhs.could_extract_minus_sign()): + mainterm *= -1 + rhs *= -1 + diff = log(mainterm) - log(rhs) + soln = _lambert(expand_log(diff), symbol) + + # For the last form: + # + # 3) d*p**(a*B + g) - b*B = c + # collect on main pow, add b*B to both sides, + # take log of both sides and rearrange to give + # a*B*log(p) - log(b*B + c) = -log(d) - g*log(p) + if not soln: + mainpow = _mostfunc(lhs, Pow, symbol) + if mainpow and symbol in mainpow.exp.free_symbols: + lhs = collect(lhs, mainpow) + if lhs.is_Mul and rhs != 0: + # b*B = 0 + soln = _lambert(expand_log(log(lhs) - log(rhs)), symbol) + elif lhs.is_Add: + # move all but mainpow-containing term to rhs + other = lhs.subs(mainpow, 0) + mainterm = lhs - other + rhs = rhs - other + diff = log(mainterm) - log(rhs) + soln = _lambert(expand_log(diff), symbol) + + if not soln: + raise NotImplementedError('%s does not appear to have a solution in ' + 'terms of LambertW' % f) + + return list(ordered(soln)) + + +def bivariate_type(f, x, y, *, first=True): + """Given an expression, f, 3 tests will be done to see what type + of composite bivariate it might be, options for u(x, y) are:: + + x*y + x+y + x*y+x + x*y+y + + If it matches one of these types, ``u(x, y)``, ``P(u)`` and dummy + variable ``u`` will be returned. Solving ``P(u)`` for ``u`` and + equating the solutions to ``u(x, y)`` and then solving for ``x`` or + ``y`` is equivalent to solving the original expression for ``x`` or + ``y``. If ``x`` and ``y`` represent two functions in the same + variable, e.g. ``x = g(t)`` and ``y = h(t)``, then if ``u(x, y) - p`` + can be solved for ``t`` then these represent the solutions to + ``P(u) = 0`` when ``p`` are the solutions of ``P(u) = 0``. + + Only positive values of ``u`` are considered. + + Examples + ======== + + >>> from sympy import solve + >>> from sympy.solvers.bivariate import bivariate_type + >>> from sympy.abc import x, y + >>> eq = (x**2 - 3).subs(x, x + y) + >>> bivariate_type(eq, x, y) + (x + y, _u**2 - 3, _u) + >>> uxy, pu, u = _ + >>> usol = solve(pu, u); usol + [sqrt(3)] + >>> [solve(uxy - s) for s in solve(pu, u)] + [[{x: -y + sqrt(3)}]] + >>> all(eq.subs(s).equals(0) for sol in _ for s in sol) + True + + """ + + u = Dummy('u', positive=True) + + if first: + p = Poly(f, x, y) + f = p.as_expr() + _x = Dummy() + _y = Dummy() + rv = bivariate_type(Poly(f.subs({x: _x, y: _y}), _x, _y), _x, _y, first=False) + if rv: + reps = {_x: x, _y: y} + return rv[0].xreplace(reps), rv[1].xreplace(reps), rv[2] + return + + p = f + f = p.as_expr() + + # f(x*y) + args = Add.make_args(p.as_expr()) + new = [] + for a in args: + a = _mexpand(a.subs(x, u/y)) + free = a.free_symbols + if x in free or y in free: + break + new.append(a) + else: + return x*y, Add(*new), u + + def ok(f, v, c): + new = _mexpand(f.subs(v, c)) + free = new.free_symbols + return None if (x in free or y in free) else new + + # f(a*x + b*y) + new = [] + d = p.degree(x) + if p.degree(y) == d: + a = root(p.coeff_monomial(x**d), d) + b = root(p.coeff_monomial(y**d), d) + new = ok(f, x, (u - b*y)/a) + if new is not None: + return a*x + b*y, new, u + + # f(a*x*y + b*y) + new = [] + d = p.degree(x) + if p.degree(y) == d: + for itry in range(2): + a = root(p.coeff_monomial(x**d*y**d), d) + b = root(p.coeff_monomial(y**d), d) + new = ok(f, x, (u - b*y)/a/y) + if new is not None: + return a*x*y + b*y, new, u + x, y = y, x diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/decompogen.py b/.venv/lib/python3.13/site-packages/sympy/solvers/decompogen.py new file mode 100644 index 0000000000000000000000000000000000000000..ec1b3b683511a34e6f98b9839d112b87517390d8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/decompogen.py @@ -0,0 +1,126 @@ +from sympy.core import (Function, Pow, sympify, Expr) +from sympy.core.relational import Relational +from sympy.core.singleton import S +from sympy.polys import Poly, decompose +from sympy.utilities.misc import func_name +from sympy.functions.elementary.miscellaneous import Min, Max + + +def decompogen(f, symbol): + """ + Computes General functional decomposition of ``f``. + Given an expression ``f``, returns a list ``[f_1, f_2, ..., f_n]``, + where:: + f = f_1 o f_2 o ... f_n = f_1(f_2(... f_n)) + + Note: This is a General decomposition function. It also decomposes + Polynomials. For only Polynomial decomposition see ``decompose`` in polys. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import decompogen, sqrt, sin, cos + >>> decompogen(sin(cos(x)), x) + [sin(x), cos(x)] + >>> decompogen(sin(x)**2 + sin(x) + 1, x) + [x**2 + x + 1, sin(x)] + >>> decompogen(sqrt(6*x**2 - 5), x) + [sqrt(x), 6*x**2 - 5] + >>> decompogen(sin(sqrt(cos(x**2 + 1))), x) + [sin(x), sqrt(x), cos(x), x**2 + 1] + >>> decompogen(x**4 + 2*x**3 - x - 1, x) + [x**2 - x - 1, x**2 + x] + + """ + f = sympify(f) + if not isinstance(f, Expr) or isinstance(f, Relational): + raise TypeError('expecting Expr but got: `%s`' % func_name(f)) + if symbol not in f.free_symbols: + return [f] + + + # ===== Simple Functions ===== # + if isinstance(f, (Function, Pow)): + if f.is_Pow and f.base == S.Exp1: + arg = f.exp + else: + arg = f.args[0] + if arg == symbol: + return [f] + return [f.subs(arg, symbol)] + decompogen(arg, symbol) + + # ===== Min/Max Functions ===== # + if isinstance(f, (Min, Max)): + args = list(f.args) + d0 = None + for i, a in enumerate(args): + if not a.has_free(symbol): + continue + d = decompogen(a, symbol) + if len(d) == 1: + d = [symbol] + d + if d0 is None: + d0 = d[1:] + elif d[1:] != d0: + # decomposition is not the same for each arg: + # mark as having no decomposition + d = [symbol] + break + args[i] = d[0] + if d[0] == symbol: + return [f] + return [f.func(*args)] + d0 + + # ===== Convert to Polynomial ===== # + fp = Poly(f) + gens = list(filter(lambda x: symbol in x.free_symbols, fp.gens)) + + if len(gens) == 1 and gens[0] != symbol: + f1 = f.subs(gens[0], symbol) + f2 = gens[0] + return [f1] + decompogen(f2, symbol) + + # ===== Polynomial decompose() ====== # + try: + return decompose(f) + except ValueError: + return [f] + + +def compogen(g_s, symbol): + """ + Returns the composition of functions. + Given a list of functions ``g_s``, returns their composition ``f``, + where: + f = g_1 o g_2 o .. o g_n + + Note: This is a General composition function. It also composes Polynomials. + For only Polynomial composition see ``compose`` in polys. + + Examples + ======== + + >>> from sympy.solvers.decompogen import compogen + >>> from sympy.abc import x + >>> from sympy import sqrt, sin, cos + >>> compogen([sin(x), cos(x)], x) + sin(cos(x)) + >>> compogen([x**2 + x + 1, sin(x)], x) + sin(x)**2 + sin(x) + 1 + >>> compogen([sqrt(x), 6*x**2 - 5], x) + sqrt(6*x**2 - 5) + >>> compogen([sin(x), sqrt(x), cos(x), x**2 + 1], x) + sin(sqrt(cos(x**2 + 1))) + >>> compogen([x**2 - x - 1, x**2 + x], x) + -x**2 - x + (x**2 + x)**2 - 1 + """ + if len(g_s) == 1: + return g_s[0] + + foo = g_s[0].subs(symbol, g_s[1]) + + if len(g_s) == 2: + return foo + + return compogen([foo] + g_s[2:], symbol) diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/deutils.py b/.venv/lib/python3.13/site-packages/sympy/solvers/deutils.py new file mode 100644 index 0000000000000000000000000000000000000000..b13f37b5004fe7bce2545ad2c788ec3feae56025 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/deutils.py @@ -0,0 +1,273 @@ +"""Utility functions for classifying and solving +ordinary and partial differential equations. + +Contains +======== +_preprocess +ode_order +_desolve + +""" +from sympy.core import Pow +from sympy.core.function import Derivative, AppliedUndef +from sympy.core.relational import Equality +from sympy.core.symbol import Wild + +def _preprocess(expr, func=None, hint='_Integral'): + """Prepare expr for solving by making sure that differentiation + is done so that only func remains in unevaluated derivatives and + (if hint does not end with _Integral) that doit is applied to all + other derivatives. If hint is None, do not do any differentiation. + (Currently this may cause some simple differential equations to + fail.) + + In case func is None, an attempt will be made to autodetect the + function to be solved for. + + >>> from sympy.solvers.deutils import _preprocess + >>> from sympy import Derivative, Function + >>> from sympy.abc import x, y, z + >>> f, g = map(Function, 'fg') + + If f(x)**p == 0 and p>0 then we can solve for f(x)=0 + >>> _preprocess((f(x).diff(x)-4)**5, f(x)) + (Derivative(f(x), x) - 4, f(x)) + + Apply doit to derivatives that contain more than the function + of interest: + + >>> _preprocess(Derivative(f(x) + x, x)) + (Derivative(f(x), x) + 1, f(x)) + + Do others if the differentiation variable(s) intersect with those + of the function of interest or contain the function of interest: + + >>> _preprocess(Derivative(g(x), y, z), f(y)) + (0, f(y)) + >>> _preprocess(Derivative(f(y), z), f(y)) + (0, f(y)) + + Do others if the hint does not end in '_Integral' (the default + assumes that it does): + + >>> _preprocess(Derivative(g(x), y), f(x)) + (Derivative(g(x), y), f(x)) + >>> _preprocess(Derivative(f(x), y), f(x), hint='') + (0, f(x)) + + Do not do any derivatives if hint is None: + + >>> eq = Derivative(f(x) + 1, x) + Derivative(f(x), y) + >>> _preprocess(eq, f(x), hint=None) + (Derivative(f(x) + 1, x) + Derivative(f(x), y), f(x)) + + If it's not clear what the function of interest is, it must be given: + + >>> eq = Derivative(f(x) + g(x), x) + >>> _preprocess(eq, g(x)) + (Derivative(f(x), x) + Derivative(g(x), x), g(x)) + >>> try: _preprocess(eq) + ... except ValueError: print("A ValueError was raised.") + A ValueError was raised. + + """ + if isinstance(expr, Pow): + # if f(x)**p=0 then f(x)=0 (p>0) + if (expr.exp).is_positive: + expr = expr.base + derivs = expr.atoms(Derivative) + if not func: + funcs = set().union(*[d.atoms(AppliedUndef) for d in derivs]) + if len(funcs) != 1: + raise ValueError('The function cannot be ' + 'automatically detected for %s.' % expr) + func = funcs.pop() + fvars = set(func.args) + if hint is None: + return expr, func + reps = [(d, d.doit()) for d in derivs if not hint.endswith('_Integral') or + d.has(func) or set(d.variables) & fvars] + eq = expr.subs(reps) + return eq, func + + +def ode_order(expr, func): + """ + Returns the order of a given differential + equation with respect to func. + + This function is implemented recursively. + + Examples + ======== + + >>> from sympy import Function + >>> from sympy.solvers.deutils import ode_order + >>> from sympy.abc import x + >>> f, g = map(Function, ['f', 'g']) + >>> ode_order(f(x).diff(x, 2) + f(x).diff(x)**2 + + ... f(x).diff(x), f(x)) + 2 + >>> ode_order(f(x).diff(x, 2) + g(x).diff(x, 3), f(x)) + 2 + >>> ode_order(f(x).diff(x, 2) + g(x).diff(x, 3), g(x)) + 3 + + """ + a = Wild('a', exclude=[func]) + if expr.match(a): + return 0 + + if isinstance(expr, Derivative): + if expr.args[0] == func: + return len(expr.variables) + else: + args = expr.args[0].args + rv = len(expr.variables) + if args: + rv += max(ode_order(_, func) for _ in args) + return rv + else: + return max(ode_order(_, func) for _ in expr.args) if expr.args else 0 + + +def _desolve(eq, func=None, hint="default", ics=None, simplify=True, *, prep=True, **kwargs): + """This is a helper function to dsolve and pdsolve in the ode + and pde modules. + + If the hint provided to the function is "default", then a dict with + the following keys are returned + + 'func' - It provides the function for which the differential equation + has to be solved. This is useful when the expression has + more than one function in it. + + 'default' - The default key as returned by classifier functions in ode + and pde.py + + 'hint' - The hint given by the user for which the differential equation + is to be solved. If the hint given by the user is 'default', + then the value of 'hint' and 'default' is the same. + + 'order' - The order of the function as returned by ode_order + + 'match' - It returns the match as given by the classifier functions, for + the default hint. + + If the hint provided to the function is not "default" and is not in + ('all', 'all_Integral', 'best'), then a dict with the above mentioned keys + is returned along with the keys which are returned when dict in + classify_ode or classify_pde is set True + + If the hint given is in ('all', 'all_Integral', 'best'), then this function + returns a nested dict, with the keys, being the set of classified hints + returned by classifier functions, and the values being the dict of form + as mentioned above. + + Key 'eq' is a common key to all the above mentioned hints which returns an + expression if eq given by user is an Equality. + + See Also + ======== + classify_ode(ode.py) + classify_pde(pde.py) + """ + if isinstance(eq, Equality): + eq = eq.lhs - eq.rhs + + # preprocess the equation and find func if not given + if prep or func is None: + eq, func = _preprocess(eq, func) + prep = False + + # type is an argument passed by the solve functions in ode and pde.py + # that identifies whether the function caller is an ordinary + # or partial differential equation. Accordingly corresponding + # changes are made in the function. + type = kwargs.get('type', None) + xi = kwargs.get('xi') + eta = kwargs.get('eta') + x0 = kwargs.get('x0', 0) + terms = kwargs.get('n') + + if type == 'ode': + from sympy.solvers.ode import classify_ode, allhints + classifier = classify_ode + string = 'ODE ' + dummy = '' + + elif type == 'pde': + from sympy.solvers.pde import classify_pde, allhints + classifier = classify_pde + string = 'PDE ' + dummy = 'p' + + # Magic that should only be used internally. Prevents classify_ode from + # being called more than it needs to be by passing its results through + # recursive calls. + if kwargs.get('classify', True): + hints = classifier(eq, func, dict=True, ics=ics, xi=xi, eta=eta, + n=terms, x0=x0, hint=hint, prep=prep) + + else: + # Here is what all this means: + # + # hint: The hint method given to _desolve() by the user. + # hints: The dictionary of hints that match the DE, along with other + # information (including the internal pass-through magic). + # default: The default hint to return, the first hint from allhints + # that matches the hint; obtained from classify_ode(). + # match: Dictionary containing the match dictionary for each hint + # (the parts of the DE for solving). When going through the + # hints in "all", this holds the match string for the current + # hint. + # order: The order of the DE, as determined by ode_order(). + hints = kwargs.get('hint', + {'default': hint, + hint: kwargs['match'], + 'order': kwargs['order']}) + if not hints['default']: + # classify_ode will set hints['default'] to None if no hints match. + if hint not in allhints and hint != 'default': + raise ValueError("Hint not recognized: " + hint) + elif hint not in hints['ordered_hints'] and hint != 'default': + raise ValueError(string + str(eq) + " does not match hint " + hint) + # If dsolve can't solve the purely algebraic equation then dsolve will raise + # ValueError + elif hints['order'] == 0: + raise ValueError( + str(eq) + " is not a solvable differential equation in " + str(func)) + else: + raise NotImplementedError(dummy + "solve" + ": Cannot solve " + str(eq)) + if hint == 'default': + return _desolve(eq, func, ics=ics, hint=hints['default'], simplify=simplify, + prep=prep, x0=x0, classify=False, order=hints['order'], + match=hints[hints['default']], xi=xi, eta=eta, n=terms, type=type) + elif hint in ('all', 'all_Integral', 'best'): + retdict = {} + gethints = set(hints) - {'order', 'default', 'ordered_hints'} + if hint == 'all_Integral': + for i in hints: + if i.endswith('_Integral'): + gethints.remove(i.removesuffix('_Integral')) + # special cases + for k in ["1st_homogeneous_coeff_best", "1st_power_series", + "lie_group", "2nd_power_series_ordinary", "2nd_power_series_regular"]: + if k in gethints: + gethints.remove(k) + for i in gethints: + sol = _desolve(eq, func, ics=ics, hint=i, x0=x0, simplify=simplify, prep=prep, + classify=False, n=terms, order=hints['order'], match=hints[i], type=type) + retdict[i] = sol + retdict['all'] = True + retdict['eq'] = eq + return retdict + elif hint not in allhints: # and hint not in ('default', 'ordered_hints'): + raise ValueError("Hint not recognized: " + hint) + elif hint not in hints: + raise ValueError(string + str(eq) + " does not match hint " + hint) + else: + # Key added to identify the hint needed to solve the equation + hints['hint'] = hint + hints.update({'func': func, 'eq': eq}) + return hints diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/inequalities.py b/.venv/lib/python3.13/site-packages/sympy/solvers/inequalities.py new file mode 100644 index 0000000000000000000000000000000000000000..f50f7a7572ff25e8f48a4214d7b0c5ec6b5b35f3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/inequalities.py @@ -0,0 +1,986 @@ +"""Tools for solving inequalities and systems of inequalities. """ +import itertools + +from sympy.calculus.util import (continuous_domain, periodicity, + function_range) +from sympy.core import sympify +from sympy.core.exprtools import factor_terms +from sympy.core.relational import Relational, Lt, Ge, Eq +from sympy.core.symbol import Symbol, Dummy +from sympy.sets.sets import Interval, FiniteSet, Union, Intersection +from sympy.core.singleton import S +from sympy.core.function import expand_mul +from sympy.functions.elementary.complexes import Abs +from sympy.logic import And +from sympy.polys import Poly, PolynomialError, parallel_poly_from_expr +from sympy.polys.polyutils import _nsort +from sympy.solvers.solveset import solvify, solveset +from sympy.utilities.iterables import sift, iterable +from sympy.utilities.misc import filldedent + + +def solve_poly_inequality(poly, rel): + """Solve a polynomial inequality with rational coefficients. + + Examples + ======== + + >>> from sympy import solve_poly_inequality, Poly + >>> from sympy.abc import x + + >>> solve_poly_inequality(Poly(x, x, domain='ZZ'), '==') + [{0}] + + >>> solve_poly_inequality(Poly(x**2 - 1, x, domain='ZZ'), '!=') + [Interval.open(-oo, -1), Interval.open(-1, 1), Interval.open(1, oo)] + + >>> solve_poly_inequality(Poly(x**2 - 1, x, domain='ZZ'), '==') + [{-1}, {1}] + + See Also + ======== + solve_poly_inequalities + """ + if not isinstance(poly, Poly): + raise ValueError( + 'For efficiency reasons, `poly` should be a Poly instance') + if poly.as_expr().is_number: + t = Relational(poly.as_expr(), 0, rel) + if t is S.true: + return [S.Reals] + elif t is S.false: + return [S.EmptySet] + else: + raise NotImplementedError( + "could not determine truth value of %s" % t) + + reals, intervals = poly.real_roots(multiple=False), [] + + if rel == '==': + for root, _ in reals: + interval = Interval(root, root) + intervals.append(interval) + elif rel == '!=': + left = S.NegativeInfinity + + for right, _ in reals + [(S.Infinity, 1)]: + interval = Interval(left, right, True, True) + intervals.append(interval) + left = right + else: + if poly.LC() > 0: + sign = +1 + else: + sign = -1 + + eq_sign, equal = None, False + + if rel == '>': + eq_sign = +1 + elif rel == '<': + eq_sign = -1 + elif rel == '>=': + eq_sign, equal = +1, True + elif rel == '<=': + eq_sign, equal = -1, True + else: + raise ValueError("'%s' is not a valid relation" % rel) + + right, right_open = S.Infinity, True + + for left, multiplicity in reversed(reals): + if multiplicity % 2: + if sign == eq_sign: + intervals.insert( + 0, Interval(left, right, not equal, right_open)) + + sign, right, right_open = -sign, left, not equal + else: + if sign == eq_sign and not equal: + intervals.insert( + 0, Interval(left, right, True, right_open)) + right, right_open = left, True + elif sign != eq_sign and equal: + intervals.insert(0, Interval(left, left)) + + if sign == eq_sign: + intervals.insert( + 0, Interval(S.NegativeInfinity, right, True, right_open)) + + return intervals + + +def solve_poly_inequalities(polys): + """Solve polynomial inequalities with rational coefficients. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.solvers.inequalities import solve_poly_inequalities + >>> from sympy.abc import x + >>> solve_poly_inequalities((( + ... Poly(x**2 - 3), ">"), ( + ... Poly(-x**2 + 1), ">"))) + Union(Interval.open(-oo, -sqrt(3)), Interval.open(-1, 1), Interval.open(sqrt(3), oo)) + """ + return Union(*[s for p in polys for s in solve_poly_inequality(*p)]) + + +def solve_rational_inequalities(eqs): + """Solve a system of rational inequalities with rational coefficients. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import solve_rational_inequalities, Poly + + >>> solve_rational_inequalities([[ + ... ((Poly(-x + 1), Poly(1, x)), '>='), + ... ((Poly(-x + 1), Poly(1, x)), '<=')]]) + {1} + + >>> solve_rational_inequalities([[ + ... ((Poly(x), Poly(1, x)), '!='), + ... ((Poly(-x + 1), Poly(1, x)), '>=')]]) + Union(Interval.open(-oo, 0), Interval.Lopen(0, 1)) + + See Also + ======== + solve_poly_inequality + """ + result = S.EmptySet + + for _eqs in eqs: + if not _eqs: + continue + + global_intervals = [Interval(S.NegativeInfinity, S.Infinity)] + + for (numer, denom), rel in _eqs: + numer_intervals = solve_poly_inequality(numer*denom, rel) + denom_intervals = solve_poly_inequality(denom, '==') + + intervals = [] + + for numer_interval, global_interval in itertools.product( + numer_intervals, global_intervals): + interval = numer_interval.intersect(global_interval) + + if interval is not S.EmptySet: + intervals.append(interval) + + global_intervals = intervals + + intervals = [] + + for global_interval in global_intervals: + for denom_interval in denom_intervals: + global_interval -= denom_interval + + if global_interval is not S.EmptySet: + intervals.append(global_interval) + + global_intervals = intervals + + if not global_intervals: + break + + for interval in global_intervals: + result = result.union(interval) + + return result + + +def reduce_rational_inequalities(exprs, gen, relational=True): + """Reduce a system of rational inequalities with rational coefficients. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.solvers.inequalities import reduce_rational_inequalities + + >>> x = Symbol('x', real=True) + + >>> reduce_rational_inequalities([[x**2 <= 0]], x) + Eq(x, 0) + + >>> reduce_rational_inequalities([[x + 2 > 0]], x) + -2 < x + >>> reduce_rational_inequalities([[(x + 2, ">")]], x) + -2 < x + >>> reduce_rational_inequalities([[x + 2]], x) + Eq(x, -2) + + This function find the non-infinite solution set so if the unknown symbol + is declared as extended real rather than real then the result may include + finiteness conditions: + + >>> y = Symbol('y', extended_real=True) + >>> reduce_rational_inequalities([[y + 2 > 0]], y) + (-2 < y) & (y < oo) + """ + exact = True + eqs = [] + solution = S.EmptySet # add pieces for each group + for _exprs in exprs: + if not _exprs: + continue + _eqs = [] + _sol = S.Reals + for expr in _exprs: + if isinstance(expr, tuple): + expr, rel = expr + else: + if expr.is_Relational: + expr, rel = expr.lhs - expr.rhs, expr.rel_op + else: + rel = '==' + + if expr is S.true: + numer, denom, rel = S.Zero, S.One, '==' + elif expr is S.false: + numer, denom, rel = S.One, S.One, '==' + else: + numer, denom = expr.together().as_numer_denom() + + try: + (numer, denom), opt = parallel_poly_from_expr( + (numer, denom), gen) + except PolynomialError: + raise PolynomialError(filldedent(''' + only polynomials and rational functions are + supported in this context. + ''')) + + if not opt.domain.is_Exact: + numer, denom, exact = numer.to_exact(), denom.to_exact(), False + + domain = opt.domain.get_exact() + + if not (domain.is_ZZ or domain.is_QQ): + expr = numer/denom + expr = Relational(expr, 0, rel) + _sol &= solve_univariate_inequality(expr, gen, relational=False) + else: + _eqs.append(((numer, denom), rel)) + + if _eqs: + _sol &= solve_rational_inequalities([_eqs]) + exclude = solve_rational_inequalities([[((d, d.one), '==') + for i in eqs for ((n, d), _) in i if d.has(gen)]]) + _sol -= exclude + + solution |= _sol + + if not exact and solution: + solution = solution.evalf() + + if relational: + solution = solution.as_relational(gen) + + return solution + + +def reduce_abs_inequality(expr, rel, gen): + """Reduce an inequality with nested absolute values. + + Examples + ======== + + >>> from sympy import reduce_abs_inequality, Abs, Symbol + >>> x = Symbol('x', real=True) + + >>> reduce_abs_inequality(Abs(x - 5) - 3, '<', x) + (2 < x) & (x < 8) + + >>> reduce_abs_inequality(Abs(x + 2)*3 - 13, '<', x) + (-19/3 < x) & (x < 7/3) + + See Also + ======== + + reduce_abs_inequalities + """ + if gen.is_extended_real is False: + raise TypeError(filldedent(''' + Cannot solve inequalities with absolute values containing + non-real variables. + ''')) + + def _bottom_up_scan(expr): + exprs = [] + + if expr.is_Add or expr.is_Mul: + op = expr.func + + for arg in expr.args: + _exprs = _bottom_up_scan(arg) + + if not exprs: + exprs = _exprs + else: + exprs = [(op(expr, _expr), conds + _conds) for (expr, conds), (_expr, _conds) in + itertools.product(exprs, _exprs)] + elif expr.is_Pow: + n = expr.exp + if not n.is_Integer: + raise ValueError("Only Integer Powers are allowed on Abs.") + + exprs.extend((expr**n, conds) for expr, conds in _bottom_up_scan(expr.base)) + elif isinstance(expr, Abs): + _exprs = _bottom_up_scan(expr.args[0]) + + for expr, conds in _exprs: + exprs.append(( expr, conds + [Ge(expr, 0)])) + exprs.append((-expr, conds + [Lt(expr, 0)])) + else: + exprs = [(expr, [])] + + return exprs + + mapping = {'<': '>', '<=': '>='} + inequalities = [] + + for expr, conds in _bottom_up_scan(expr): + if rel not in mapping.keys(): + expr = Relational( expr, 0, rel) + else: + expr = Relational(-expr, 0, mapping[rel]) + + inequalities.append([expr] + conds) + + return reduce_rational_inequalities(inequalities, gen) + + +def reduce_abs_inequalities(exprs, gen): + """Reduce a system of inequalities with nested absolute values. + + Examples + ======== + + >>> from sympy import reduce_abs_inequalities, Abs, Symbol + >>> x = Symbol('x', extended_real=True) + + >>> reduce_abs_inequalities([(Abs(3*x - 5) - 7, '<'), + ... (Abs(x + 25) - 13, '>')], x) + (-2/3 < x) & (x < 4) & (((-oo < x) & (x < -38)) | ((-12 < x) & (x < oo))) + + >>> reduce_abs_inequalities([(Abs(x - 4) + Abs(3*x - 5) - 7, '<')], x) + (1/2 < x) & (x < 4) + + See Also + ======== + + reduce_abs_inequality + """ + return And(*[ reduce_abs_inequality(expr, rel, gen) + for expr, rel in exprs ]) + + +def solve_univariate_inequality(expr, gen, relational=True, domain=S.Reals, continuous=False): + """Solves a real univariate inequality. + + Parameters + ========== + + expr : Relational + The target inequality + gen : Symbol + The variable for which the inequality is solved + relational : bool + A Relational type output is expected or not + domain : Set + The domain over which the equation is solved + continuous: bool + True if expr is known to be continuous over the given domain + (and so continuous_domain() does not need to be called on it) + + Raises + ====== + + NotImplementedError + The solution of the inequality cannot be determined due to limitation + in :func:`sympy.solvers.solveset.solvify`. + + Notes + ===== + + Currently, we cannot solve all the inequalities due to limitations in + :func:`sympy.solvers.solveset.solvify`. Also, the solution returned for trigonometric inequalities + are restricted in its periodic interval. + + See Also + ======== + + sympy.solvers.solveset.solvify: solver returning solveset solutions with solve's output API + + Examples + ======== + + >>> from sympy import solve_univariate_inequality, Symbol, sin, Interval, S + >>> x = Symbol('x') + + >>> solve_univariate_inequality(x**2 >= 4, x) + ((2 <= x) & (x < oo)) | ((-oo < x) & (x <= -2)) + + >>> solve_univariate_inequality(x**2 >= 4, x, relational=False) + Union(Interval(-oo, -2), Interval(2, oo)) + + >>> domain = Interval(0, S.Infinity) + >>> solve_univariate_inequality(x**2 >= 4, x, False, domain) + Interval(2, oo) + + >>> solve_univariate_inequality(sin(x) > 0, x, relational=False) + Interval.open(0, pi) + + """ + from sympy.solvers.solvers import denoms + + if domain.is_subset(S.Reals) is False: + raise NotImplementedError(filldedent(''' + Inequalities in the complex domain are + not supported. Try the real domain by + setting domain=S.Reals''')) + elif domain is not S.Reals: + rv = solve_univariate_inequality( + expr, gen, relational=False, continuous=continuous).intersection(domain) + if relational: + rv = rv.as_relational(gen) + return rv + else: + pass # continue with attempt to solve in Real domain + + # This keeps the function independent of the assumptions about `gen`. + # `solveset` makes sure this function is called only when the domain is + # real. + _gen = gen + _domain = domain + if gen.is_extended_real is False: + rv = S.EmptySet + return rv if not relational else rv.as_relational(_gen) + elif gen.is_extended_real is None: + gen = Dummy('gen', extended_real=True) + try: + expr = expr.xreplace({_gen: gen}) + except TypeError: + raise TypeError(filldedent(''' + When gen is real, the relational has a complex part + which leads to an invalid comparison like I < 0. + ''')) + + rv = None + + if expr is S.true: + rv = domain + + elif expr is S.false: + rv = S.EmptySet + + else: + e = expr.lhs - expr.rhs + period = periodicity(e, gen) + if period == S.Zero: + e = expand_mul(e) + const = expr.func(e, 0) + if const is S.true: + rv = domain + elif const is S.false: + rv = S.EmptySet + elif period is not None: + frange = function_range(e, gen, domain) + + rel = expr.rel_op + if rel in ('<', '<='): + if expr.func(frange.sup, 0): + rv = domain + elif not expr.func(frange.inf, 0): + rv = S.EmptySet + + elif rel in ('>', '>='): + if expr.func(frange.inf, 0): + rv = domain + elif not expr.func(frange.sup, 0): + rv = S.EmptySet + + inf, sup = domain.inf, domain.sup + if sup - inf is S.Infinity: + domain = Interval(0, period, False, True).intersect(_domain) + _domain = domain + + if rv is None: + n, d = e.as_numer_denom() + try: + if gen not in n.free_symbols and len(e.free_symbols) > 1: + raise ValueError + # this might raise ValueError on its own + # or it might give None... + solns = solvify(e, gen, domain) + if solns is None: + # in which case we raise ValueError + raise ValueError + except (ValueError, NotImplementedError): + # replace gen with generic x since it's + # univariate anyway + raise NotImplementedError(filldedent(''' + The inequality, %s, cannot be solved using + solve_univariate_inequality. + ''' % expr.subs(gen, Symbol('x')))) + + expanded_e = expand_mul(e) + def valid(x): + # this is used to see if gen=x satisfies the + # relational by substituting it into the + # expanded form and testing against 0, e.g. + # if expr = x*(x + 1) < 2 then e = x*(x + 1) - 2 + # and expanded_e = x**2 + x - 2; the test is + # whether a given value of x satisfies + # x**2 + x - 2 < 0 + # + # expanded_e, expr and gen used from enclosing scope + v = expanded_e.subs(gen, expand_mul(x)) + try: + r = expr.func(v, 0) + except TypeError: + r = S.false + if r in (S.true, S.false): + return r + if v.is_extended_real is False: + return S.false + else: + v = v.n(2) + if v.is_comparable: + return expr.func(v, 0) + # not comparable or couldn't be evaluated + raise NotImplementedError( + 'relationship did not evaluate: %s' % r) + + singularities = [] + for d in denoms(expr, gen): + singularities.extend(solvify(d, gen, domain)) + if not continuous: + domain = continuous_domain(expanded_e, gen, domain) + + include_x = '=' in expr.rel_op and expr.rel_op != '!=' + + try: + discontinuities = set(domain.boundary - + FiniteSet(domain.inf, domain.sup)) + # remove points that are not between inf and sup of domain + critical_points = FiniteSet(*(solns + singularities + list( + discontinuities))).intersection( + Interval(domain.inf, domain.sup, + domain.inf not in domain, domain.sup not in domain)) + if all(r.is_number for r in critical_points): + reals = _nsort(critical_points, separated=True)[0] + else: + sifted = sift(critical_points, lambda x: x.is_extended_real) + if sifted[None]: + # there were some roots that weren't known + # to be real + raise NotImplementedError + try: + reals = sifted[True] + if len(reals) > 1: + reals = sorted(reals) + except TypeError: + raise NotImplementedError + except NotImplementedError: + raise NotImplementedError('sorting of these roots is not supported') + + # If expr contains imaginary coefficients, only take real + # values of x for which the imaginary part is 0 + make_real = S.Reals + if (coeffI := expanded_e.coeff(S.ImaginaryUnit)) != S.Zero: + check = True + im_sol = FiniteSet() + try: + a = solveset(coeffI, gen, domain) + if not isinstance(a, Interval): + for z in a: + if z not in singularities and valid(z) and z.is_extended_real: + im_sol += FiniteSet(z) + else: + start, end = a.inf, a.sup + for z in _nsort(critical_points + FiniteSet(end)): + valid_start = valid(start) + if start != end: + valid_z = valid(z) + pt = _pt(start, z) + if pt not in singularities and pt.is_extended_real and valid(pt): + if valid_start and valid_z: + im_sol += Interval(start, z) + elif valid_start: + im_sol += Interval.Ropen(start, z) + elif valid_z: + im_sol += Interval.Lopen(start, z) + else: + im_sol += Interval.open(start, z) + start = z + for s in singularities: + im_sol -= FiniteSet(s) + except (TypeError): + im_sol = S.Reals + check = False + + if im_sol is S.EmptySet: + raise ValueError(filldedent(''' + %s contains imaginary parts which cannot be + made 0 for any value of %s satisfying the + inequality, leading to relations like I < 0. + ''' % (expr.subs(gen, _gen), _gen))) + + make_real = make_real.intersect(im_sol) + + sol_sets = [S.EmptySet] + + start = domain.inf + if start in domain and valid(start) and start.is_finite: + sol_sets.append(FiniteSet(start)) + + for x in reals: + end = x + + if valid(_pt(start, end)): + sol_sets.append(Interval(start, end, True, True)) + + if x in singularities: + singularities.remove(x) + else: + if x in discontinuities: + discontinuities.remove(x) + _valid = valid(x) + else: # it's a solution + _valid = include_x + if _valid: + sol_sets.append(FiniteSet(x)) + + start = end + + end = domain.sup + if end in domain and valid(end) and end.is_finite: + sol_sets.append(FiniteSet(end)) + + if valid(_pt(start, end)): + sol_sets.append(Interval.open(start, end)) + + if coeffI != S.Zero and check: + rv = (make_real).intersect(_domain) + else: + rv = Intersection( + (Union(*sol_sets)), make_real, _domain).subs(gen, _gen) + + return rv if not relational else rv.as_relational(_gen) + + +def _pt(start, end): + """Return a point between start and end""" + if not start.is_infinite and not end.is_infinite: + pt = (start + end)/2 + elif start.is_infinite and end.is_infinite: + pt = S.Zero + else: + if (start.is_infinite and start.is_extended_positive is None or + end.is_infinite and end.is_extended_positive is None): + raise ValueError('cannot proceed with unsigned infinite values') + if (end.is_infinite and end.is_extended_negative or + start.is_infinite and start.is_extended_positive): + start, end = end, start + # if possible, use a multiple of self which has + # better behavior when checking assumptions than + # an expression obtained by adding or subtracting 1 + if end.is_infinite: + if start.is_extended_positive: + pt = start*2 + elif start.is_extended_negative: + pt = start*S.Half + else: + pt = start + 1 + elif start.is_infinite: + if end.is_extended_positive: + pt = end*S.Half + elif end.is_extended_negative: + pt = end*2 + else: + pt = end - 1 + return pt + + +def _solve_inequality(ie, s, linear=False): + """Return the inequality with s isolated on the left, if possible. + If the relationship is non-linear, a solution involving And or Or + may be returned. False or True are returned if the relationship + is never True or always True, respectively. + + If `linear` is True (default is False) an `s`-dependent expression + will be isolated on the left, if possible + but it will not be solved for `s` unless the expression is linear + in `s`. Furthermore, only "safe" operations which do not change the + sense of the relationship are applied: no division by an unsigned + value is attempted unless the relationship involves Eq or Ne and + no division by a value not known to be nonzero is ever attempted. + + Examples + ======== + + >>> from sympy import Eq, Symbol + >>> from sympy.solvers.inequalities import _solve_inequality as f + >>> from sympy.abc import x, y + + For linear expressions, the symbol can be isolated: + + >>> f(x - 2 < 0, x) + x < 2 + >>> f(-x - 6 < x, x) + x > -3 + + Sometimes nonlinear relationships will be False + + >>> f(x**2 + 4 < 0, x) + False + + Or they may involve more than one region of values: + + >>> f(x**2 - 4 < 0, x) + (-2 < x) & (x < 2) + + To restrict the solution to a relational, set linear=True + and only the x-dependent portion will be isolated on the left: + + >>> f(x**2 - 4 < 0, x, linear=True) + x**2 < 4 + + Division of only nonzero quantities is allowed, so x cannot + be isolated by dividing by y: + + >>> y.is_nonzero is None # it is unknown whether it is 0 or not + True + >>> f(x*y < 1, x) + x*y < 1 + + And while an equality (or inequality) still holds after dividing by a + non-zero quantity + + >>> nz = Symbol('nz', nonzero=True) + >>> f(Eq(x*nz, 1), x) + Eq(x, 1/nz) + + the sign must be known for other inequalities involving > or <: + + >>> f(x*nz <= 1, x) + nz*x <= 1 + >>> p = Symbol('p', positive=True) + >>> f(x*p <= 1, x) + x <= 1/p + + When there are denominators in the original expression that + are removed by expansion, conditions for them will be returned + as part of the result: + + >>> f(x < x*(2/x - 1), x) + (x < 1) & Ne(x, 0) + """ + from sympy.solvers.solvers import denoms + if s not in ie.free_symbols: + return ie + if ie.rhs == s: + ie = ie.reversed + if ie.lhs == s and s not in ie.rhs.free_symbols: + return ie + + def classify(ie, s, i): + # return True or False if ie evaluates when substituting s with + # i else None (if unevaluated) or NaN (when there is an error + # in evaluating) + try: + v = ie.subs(s, i) + if v is S.NaN: + return v + elif v not in (True, False): + return + return v + except TypeError: + return S.NaN + + rv = None + oo = S.Infinity + expr = ie.lhs - ie.rhs + try: + p = Poly(expr, s) + if p.degree() == 0: + rv = ie.func(p.as_expr(), 0) + elif not linear and p.degree() > 1: + # handle in except clause + raise NotImplementedError + except (PolynomialError, NotImplementedError): + if not linear: + try: + rv = reduce_rational_inequalities([[ie]], s) + except PolynomialError: + rv = solve_univariate_inequality(ie, s) + # remove restrictions wrt +/-oo that may have been + # applied when using sets to simplify the relationship + okoo = classify(ie, s, oo) + if okoo is S.true and classify(rv, s, oo) is S.false: + rv = rv.subs(s < oo, True) + oknoo = classify(ie, s, -oo) + if (oknoo is S.true and + classify(rv, s, -oo) is S.false): + rv = rv.subs(-oo < s, True) + rv = rv.subs(s > -oo, True) + if rv is S.true: + rv = (s <= oo) if okoo is S.true else (s < oo) + if oknoo is not S.true: + rv = And(-oo < s, rv) + else: + p = Poly(expr) + + conds = [] + if rv is None: + e = p.as_expr() # this is in expanded form + # Do a safe inversion of e, moving non-s terms + # to the rhs and dividing by a nonzero factor if + # the relational is Eq/Ne; for other relationals + # the sign must also be positive or negative + rhs = 0 + b, ax = e.as_independent(s, as_Add=True) + e -= b + rhs -= b + ef = factor_terms(e) + a, e = ef.as_independent(s, as_Add=False) + if (a.is_zero != False or # don't divide by potential 0 + a.is_negative == + a.is_positive is None and # if sign is not known then + ie.rel_op not in ('!=', '==')): # reject if not Eq/Ne + e = ef + a = S.One + rhs /= a + if a.is_positive: + rv = ie.func(e, rhs) + else: + rv = ie.reversed.func(e, rhs) + + # return conditions under which the value is + # valid, too. + beginning_denoms = denoms(ie.lhs) | denoms(ie.rhs) + current_denoms = denoms(rv) + for d in beginning_denoms - current_denoms: + c = _solve_inequality(Eq(d, 0), s, linear=linear) + if isinstance(c, Eq) and c.lhs == s: + if classify(rv, s, c.rhs) is S.true: + # rv is permitting this value but it shouldn't + conds.append(~c) + for i in (-oo, oo): + if (classify(rv, s, i) is S.true and + classify(ie, s, i) is not S.true): + conds.append(s < i if i is oo else i < s) + + conds.append(rv) + return And(*conds) + + +def _reduce_inequalities(inequalities, symbols): + # helper for reduce_inequalities + + poly_part, abs_part = {}, {} + other = [] + + for inequality in inequalities: + + expr, rel = inequality.lhs, inequality.rel_op # rhs is 0 + + # check for gens using atoms which is more strict than free_symbols to + # guard against EX domain which won't be handled by + # reduce_rational_inequalities + gens = expr.atoms(Symbol) + + if len(gens) == 1: + gen = gens.pop() + else: + common = expr.free_symbols & symbols + if len(common) == 1: + gen = common.pop() + other.append(_solve_inequality(Relational(expr, 0, rel), gen)) + continue + else: + raise NotImplementedError(filldedent(''' + inequality has more than one symbol of interest. + ''')) + + if expr.is_polynomial(gen): + poly_part.setdefault(gen, []).append((expr, rel)) + else: + components = expr.find(lambda u: + u.has(gen) and ( + u.is_Function or u.is_Pow and not u.exp.is_Integer)) + if components and all(isinstance(i, Abs) for i in components): + abs_part.setdefault(gen, []).append((expr, rel)) + else: + other.append(_solve_inequality(Relational(expr, 0, rel), gen)) + + poly_reduced = [reduce_rational_inequalities([exprs], gen) for gen, exprs in poly_part.items()] + abs_reduced = [reduce_abs_inequalities(exprs, gen) for gen, exprs in abs_part.items()] + + return And(*(poly_reduced + abs_reduced + other)) + + +def reduce_inequalities(inequalities, symbols=[]): + """Reduce a system of inequalities with rational coefficients. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import reduce_inequalities + + >>> reduce_inequalities(0 <= x + 3, []) + (-3 <= x) & (x < oo) + + >>> reduce_inequalities(0 <= x + y*2 - 1, [x]) + (x < oo) & (x >= 1 - 2*y) + """ + if not iterable(inequalities): + inequalities = [inequalities] + inequalities = [sympify(i) for i in inequalities] + + gens = set().union(*[i.free_symbols for i in inequalities]) + + if not iterable(symbols): + symbols = [symbols] + symbols = (set(symbols) or gens) & gens + if any(i.is_extended_real is False for i in symbols): + raise TypeError(filldedent(''' + inequalities cannot contain symbols that are not real. + ''')) + + # make vanilla symbol real + recast = {i: Dummy(i.name, extended_real=True) + for i in gens if i.is_extended_real is None} + inequalities = [i.xreplace(recast) for i in inequalities] + symbols = {i.xreplace(recast) for i in symbols} + + # prefilter + keep = [] + for i in inequalities: + if isinstance(i, Relational): + i = i.func(i.lhs.as_expr() - i.rhs.as_expr(), 0) + elif i not in (True, False): + i = Eq(i, 0) + if i == True: + continue + elif i == False: + return S.false + if i.lhs.is_number: + raise NotImplementedError( + "could not determine truth value of %s" % i) + keep.append(i) + inequalities = keep + del keep + + # solve system + rv = _reduce_inequalities(inequalities, symbols) + + # restore original symbols and return + return rv.xreplace({v: k for k, v in recast.items()}) diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/pde.py b/.venv/lib/python3.13/site-packages/sympy/solvers/pde.py new file mode 100644 index 0000000000000000000000000000000000000000..791ac67ae681ea952ea6e1dabacb7220d1843ebc --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/pde.py @@ -0,0 +1,966 @@ +""" +This module contains pdsolve() and different helper functions that it +uses. It is heavily inspired by the ode module and hence the basic +infrastructure remains the same. + +**Functions in this module** + + These are the user functions in this module: + + - pdsolve() - Solves PDE's + - classify_pde() - Classifies PDEs into possible hints for dsolve(). + - pde_separate() - Separate variables in partial differential equation either by + additive or multiplicative separation approach. + + These are the helper functions in this module: + + - pde_separate_add() - Helper function for searching additive separable solutions. + - pde_separate_mul() - Helper function for searching multiplicative + separable solutions. + +**Currently implemented solver methods** + +The following methods are implemented for solving partial differential +equations. See the docstrings of the various pde_hint() functions for +more information on each (run help(pde)): + + - 1st order linear homogeneous partial differential equations + with constant coefficients. + - 1st order linear general partial differential equations + with constant coefficients. + - 1st order linear partial differential equations with + variable coefficients. + +""" +from functools import reduce + +from itertools import combinations_with_replacement +from sympy.simplify import simplify # type: ignore +from sympy.core import Add, S +from sympy.core.function import Function, expand, AppliedUndef, Subs +from sympy.core.relational import Equality, Eq +from sympy.core.symbol import Symbol, Wild, symbols +from sympy.functions import exp +from sympy.integrals.integrals import Integral, integrate +from sympy.utilities.iterables import has_dups, is_sequence +from sympy.utilities.misc import filldedent + +from sympy.solvers.deutils import _preprocess, ode_order, _desolve +from sympy.solvers.solvers import solve +from sympy.simplify.radsimp import collect + +import operator + + +allhints = ( + "1st_linear_constant_coeff_homogeneous", + "1st_linear_constant_coeff", + "1st_linear_constant_coeff_Integral", + "1st_linear_variable_coeff" + ) + + +def pdsolve(eq, func=None, hint='default', dict=False, solvefun=None, **kwargs): + """ + Solves any (supported) kind of partial differential equation. + + **Usage** + + pdsolve(eq, f(x,y), hint) -> Solve partial differential equation + eq for function f(x,y), using method hint. + + **Details** + + ``eq`` can be any supported partial differential equation (see + the pde docstring for supported methods). This can either + be an Equality, or an expression, which is assumed to be + equal to 0. + + ``f(x,y)`` is a function of two variables whose derivatives in that + variable make up the partial differential equation. In many + cases it is not necessary to provide this; it will be autodetected + (and an error raised if it could not be detected). + + ``hint`` is the solving method that you want pdsolve to use. Use + classify_pde(eq, f(x,y)) to get all of the possible hints for + a PDE. The default hint, 'default', will use whatever hint + is returned first by classify_pde(). See Hints below for + more options that you can use for hint. + + ``solvefun`` is the convention used for arbitrary functions returned + by the PDE solver. If not set by the user, it is set by default + to be F. + + **Hints** + + Aside from the various solving methods, there are also some + meta-hints that you can pass to pdsolve(): + + "default": + This uses whatever hint is returned first by + classify_pde(). This is the default argument to + pdsolve(). + + "all": + To make pdsolve apply all relevant classification hints, + use pdsolve(PDE, func, hint="all"). This will return a + dictionary of hint:solution terms. If a hint causes + pdsolve to raise the NotImplementedError, value of that + hint's key will be the exception object raised. The + dictionary will also include some special keys: + + - order: The order of the PDE. See also ode_order() in + deutils.py + - default: The solution that would be returned by + default. This is the one produced by the hint that + appears first in the tuple returned by classify_pde(). + + "all_Integral": + This is the same as "all", except if a hint also has a + corresponding "_Integral" hint, it only returns the + "_Integral" hint. This is useful if "all" causes + pdsolve() to hang because of a difficult or impossible + integral. This meta-hint will also be much faster than + "all", because integrate() is an expensive routine. + + See also the classify_pde() docstring for more info on hints, + and the pde docstring for a list of all supported hints. + + **Tips** + - You can declare the derivative of an unknown function this way: + + >>> from sympy import Function, Derivative + >>> from sympy.abc import x, y # x and y are the independent variables + >>> f = Function("f")(x, y) # f is a function of x and y + >>> # fx will be the partial derivative of f with respect to x + >>> fx = Derivative(f, x) + >>> # fy will be the partial derivative of f with respect to y + >>> fy = Derivative(f, y) + + - See test_pde.py for many tests, which serves also as a set of + examples for how to use pdsolve(). + - pdsolve always returns an Equality class (except for the case + when the hint is "all" or "all_Integral"). Note that it is not possible + to get an explicit solution for f(x, y) as in the case of ODE's + - Do help(pde.pde_hintname) to get help more information on a + specific hint + + + Examples + ======== + + >>> from sympy.solvers.pde import pdsolve + >>> from sympy import Function, Eq + >>> from sympy.abc import x, y + >>> f = Function('f') + >>> u = f(x, y) + >>> ux = u.diff(x) + >>> uy = u.diff(y) + >>> eq = Eq(1 + (2*(ux/u)) + (3*(uy/u)), 0) + >>> pdsolve(eq) + Eq(f(x, y), F(3*x - 2*y)*exp(-2*x/13 - 3*y/13)) + + """ + + if not solvefun: + solvefun = Function('F') + + # See the docstring of _desolve for more details. + hints = _desolve(eq, func=func, hint=hint, simplify=True, + type='pde', **kwargs) + eq = hints.pop('eq', False) + all_ = hints.pop('all', False) + + if all_: + # TODO : 'best' hint should be implemented when adequate + # number of hints are added. + pdedict = {} + failed_hints = {} + gethints = classify_pde(eq, dict=True) + pdedict.update({'order': gethints['order'], + 'default': gethints['default']}) + for hint in hints: + try: + rv = _helper_simplify(eq, hint, hints[hint]['func'], + hints[hint]['order'], hints[hint][hint], solvefun) + except NotImplementedError as detail: + failed_hints[hint] = detail + else: + pdedict[hint] = rv + pdedict.update(failed_hints) + return pdedict + + else: + return _helper_simplify(eq, hints['hint'], hints['func'], + hints['order'], hints[hints['hint']], solvefun) + + +def _helper_simplify(eq, hint, func, order, match, solvefun): + """Helper function of pdsolve that calls the respective + pde functions to solve for the partial differential + equations. This minimizes the computation in + calling _desolve multiple times. + """ + solvefunc = globals()["pde_" + hint.removesuffix("_Integral")] + return _handle_Integral(solvefunc(eq, func, order, + match, solvefun), func, order, hint) + + +def _handle_Integral(expr, func, order, hint): + r""" + Converts a solution with integrals in it into an actual solution. + + Simplifies the integral mainly using doit() + """ + if hint.endswith("_Integral"): + return expr + + elif hint == "1st_linear_constant_coeff": + return simplify(expr.doit()) + + else: + return expr + + +def classify_pde(eq, func=None, dict=False, *, prep=True, **kwargs): + """ + Returns a tuple of possible pdsolve() classifications for a PDE. + + The tuple is ordered so that first item is the classification that + pdsolve() uses to solve the PDE by default. In general, + classifications near the beginning of the list will produce + better solutions faster than those near the end, though there are + always exceptions. To make pdsolve use a different classification, + use pdsolve(PDE, func, hint=). See also the pdsolve() + docstring for different meta-hints you can use. + + If ``dict`` is true, classify_pde() will return a dictionary of + hint:match expression terms. This is intended for internal use by + pdsolve(). Note that because dictionaries are ordered arbitrarily, + this will most likely not be in the same order as the tuple. + + You can get help on different hints by doing help(pde.pde_hintname), + where hintname is the name of the hint without "_Integral". + + See sympy.pde.allhints or the sympy.pde docstring for a list of all + supported hints that can be returned from classify_pde. + + + Examples + ======== + + >>> from sympy.solvers.pde import classify_pde + >>> from sympy import Function, Eq + >>> from sympy.abc import x, y + >>> f = Function('f') + >>> u = f(x, y) + >>> ux = u.diff(x) + >>> uy = u.diff(y) + >>> eq = Eq(1 + (2*(ux/u)) + (3*(uy/u)), 0) + >>> classify_pde(eq) + ('1st_linear_constant_coeff_homogeneous',) + """ + + if func and len(func.args) != 2: + raise NotImplementedError("Right now only partial " + "differential equations of two variables are supported") + + if prep or func is None: + prep, func_ = _preprocess(eq, func) + if func is None: + func = func_ + + if isinstance(eq, Equality): + if eq.rhs != 0: + return classify_pde(eq.lhs - eq.rhs, func) + eq = eq.lhs + + f = func.func + x = func.args[0] + y = func.args[1] + fx = f(x,y).diff(x) + fy = f(x,y).diff(y) + + # TODO : For now pde.py uses support offered by the ode_order function + # to find the order with respect to a multi-variable function. An + # improvement could be to classify the order of the PDE on the basis of + # individual variables. + order = ode_order(eq, f(x,y)) + + # hint:matchdict or hint:(tuple of matchdicts) + # Also will contain "default": and "order":order items. + matching_hints = {'order': order} + + if not order: + if dict: + matching_hints["default"] = None + return matching_hints + return () + + eq = expand(eq) + + a = Wild('a', exclude = [f(x,y)]) + b = Wild('b', exclude = [f(x,y), fx, fy, x, y]) + c = Wild('c', exclude = [f(x,y), fx, fy, x, y]) + d = Wild('d', exclude = [f(x,y), fx, fy, x, y]) + e = Wild('e', exclude = [f(x,y), fx, fy]) + n = Wild('n', exclude = [x, y]) + # Try removing the smallest power of f(x,y) + # from the highest partial derivatives of f(x,y) + reduced_eq = eq + if eq.is_Add: + power = None + for i in set(combinations_with_replacement((x,y), order)): + coeff = eq.coeff(f(x,y).diff(*i)) + if coeff == 1: + continue + match = coeff.match(a*f(x,y)**n) + if match and match[a]: + if power is None or match[n] < power: + power = match[n] + if power: + den = f(x,y)**power + reduced_eq = Add(*[arg/den for arg in eq.args]) + + if order == 1: + reduced_eq = collect(reduced_eq, f(x, y)) + r = reduced_eq.match(b*fx + c*fy + d*f(x,y) + e) + if r: + if not r[e]: + ## Linear first-order homogeneous partial-differential + ## equation with constant coefficients + r.update({'b': b, 'c': c, 'd': d}) + matching_hints["1st_linear_constant_coeff_homogeneous"] = r + elif r[b]**2 + r[c]**2 != 0: + ## Linear first-order general partial-differential + ## equation with constant coefficients + r.update({'b': b, 'c': c, 'd': d, 'e': e}) + matching_hints["1st_linear_constant_coeff"] = r + matching_hints["1st_linear_constant_coeff_Integral"] = r + + else: + b = Wild('b', exclude=[f(x, y), fx, fy]) + c = Wild('c', exclude=[f(x, y), fx, fy]) + d = Wild('d', exclude=[f(x, y), fx, fy]) + r = reduced_eq.match(b*fx + c*fy + d*f(x,y) + e) + if r: + r.update({'b': b, 'c': c, 'd': d, 'e': e}) + matching_hints["1st_linear_variable_coeff"] = r + + # Order keys based on allhints. + rettuple = tuple(i for i in allhints if i in matching_hints) + + if dict: + # Dictionaries are ordered arbitrarily, so make note of which + # hint would come first for pdsolve(). Use an ordered dict in Py 3. + matching_hints["default"] = None + matching_hints["ordered_hints"] = rettuple + for i in allhints: + if i in matching_hints: + matching_hints["default"] = i + break + return matching_hints + return rettuple + + +def checkpdesol(pde, sol, func=None, solve_for_func=True): + """ + Checks if the given solution satisfies the partial differential + equation. + + pde is the partial differential equation which can be given in the + form of an equation or an expression. sol is the solution for which + the pde is to be checked. This can also be given in an equation or + an expression form. If the function is not provided, the helper + function _preprocess from deutils is used to identify the function. + + If a sequence of solutions is passed, the same sort of container will be + used to return the result for each solution. + + The following methods are currently being implemented to check if the + solution satisfies the PDE: + + 1. Directly substitute the solution in the PDE and check. If the + solution has not been solved for f, then it will solve for f + provided solve_for_func has not been set to False. + + If the solution satisfies the PDE, then a tuple (True, 0) is returned. + Otherwise a tuple (False, expr) where expr is the value obtained + after substituting the solution in the PDE. However if a known solution + returns False, it may be due to the inability of doit() to simplify it to zero. + + Examples + ======== + + >>> from sympy import Function, symbols + >>> from sympy.solvers.pde import checkpdesol, pdsolve + >>> x, y = symbols('x y') + >>> f = Function('f') + >>> eq = 2*f(x,y) + 3*f(x,y).diff(x) + 4*f(x,y).diff(y) + >>> sol = pdsolve(eq) + >>> assert checkpdesol(eq, sol)[0] + >>> eq = x*f(x,y) + f(x,y).diff(x) + >>> checkpdesol(eq, sol) + (False, (x*F(4*x - 3*y) - 6*F(4*x - 3*y)/25 + 4*Subs(Derivative(F(_xi_1), _xi_1), _xi_1, 4*x - 3*y))*exp(-6*x/25 - 8*y/25)) + """ + + # Converting the pde into an equation + if not isinstance(pde, Equality): + pde = Eq(pde, 0) + + # If no function is given, try finding the function present. + if func is None: + try: + _, func = _preprocess(pde.lhs) + except ValueError: + funcs = [s.atoms(AppliedUndef) for s in ( + sol if is_sequence(sol, set) else [sol])] + funcs = set().union(funcs) + if len(funcs) != 1: + raise ValueError( + 'must pass func arg to checkpdesol for this case.') + func = funcs.pop() + + # If the given solution is in the form of a list or a set + # then return a list or set of tuples. + if is_sequence(sol, set): + return type(sol)([checkpdesol( + pde, i, func=func, + solve_for_func=solve_for_func) for i in sol]) + + # Convert solution into an equation + if not isinstance(sol, Equality): + sol = Eq(func, sol) + elif sol.rhs == func: + sol = sol.reversed + + # Try solving for the function + solved = sol.lhs == func and not sol.rhs.has(func) + if solve_for_func and not solved: + solved = solve(sol, func) + if solved: + if len(solved) == 1: + return checkpdesol(pde, Eq(func, solved[0]), + func=func, solve_for_func=False) + else: + return checkpdesol(pde, [Eq(func, t) for t in solved], + func=func, solve_for_func=False) + + # try direct substitution of the solution into the PDE and simplify + if sol.lhs == func: + pde = pde.lhs - pde.rhs + s = simplify(pde.subs(func, sol.rhs).doit()) + return s is S.Zero, s + + raise NotImplementedError(filldedent(''' + Unable to test if %s is a solution to %s.''' % (sol, pde))) + + + +def pde_1st_linear_constant_coeff_homogeneous(eq, func, order, match, solvefun): + r""" + Solves a first order linear homogeneous + partial differential equation with constant coefficients. + + The general form of this partial differential equation is + + .. math:: a \frac{\partial f(x,y)}{\partial x} + + b \frac{\partial f(x,y)}{\partial y} + c f(x,y) = 0 + + where `a`, `b` and `c` are constants. + + The general solution is of the form: + + .. math:: + f(x, y) = F(- a y + b x ) e^{- \frac{c (a x + b y)}{a^2 + b^2}} + + and can be found in SymPy with ``pdsolve``:: + + >>> from sympy.solvers import pdsolve + >>> from sympy.abc import x, y, a, b, c + >>> from sympy import Function, pprint + >>> f = Function('f') + >>> u = f(x,y) + >>> ux = u.diff(x) + >>> uy = u.diff(y) + >>> genform = a*ux + b*uy + c*u + >>> pprint(genform) + d d + a*--(f(x, y)) + b*--(f(x, y)) + c*f(x, y) + dx dy + + >>> pprint(pdsolve(genform)) + -c*(a*x + b*y) + --------------- + 2 2 + a + b + f(x, y) = F(-a*y + b*x)*e + + Examples + ======== + + >>> from sympy import pdsolve + >>> from sympy import Function, pprint + >>> from sympy.abc import x,y + >>> f = Function('f') + >>> pdsolve(f(x,y) + f(x,y).diff(x) + f(x,y).diff(y)) + Eq(f(x, y), F(x - y)*exp(-x/2 - y/2)) + >>> pprint(pdsolve(f(x,y) + f(x,y).diff(x) + f(x,y).diff(y))) + x y + - - - - + 2 2 + f(x, y) = F(x - y)*e + + References + ========== + + - Viktor Grigoryan, "Partial Differential Equations" + Math 124A - Fall 2010, pp.7 + + """ + # TODO : For now homogeneous first order linear PDE's having + # two variables are implemented. Once there is support for + # solving systems of ODE's, this can be extended to n variables. + + f = func.func + x = func.args[0] + y = func.args[1] + b = match[match['b']] + c = match[match['c']] + d = match[match['d']] + return Eq(f(x,y), exp(-S(d)/(b**2 + c**2)*(b*x + c*y))*solvefun(c*x - b*y)) + + +def pde_1st_linear_constant_coeff(eq, func, order, match, solvefun): + r""" + Solves a first order linear partial differential equation + with constant coefficients. + + The general form of this partial differential equation is + + .. math:: a \frac{\partial f(x,y)}{\partial x} + + b \frac{\partial f(x,y)}{\partial y} + + c f(x,y) = G(x,y) + + where `a`, `b` and `c` are constants and `G(x, y)` can be an arbitrary + function in `x` and `y`. + + The general solution of the PDE is: + + .. math:: + f(x, y) = \left. \left[F(\eta) + \frac{1}{a^2 + b^2} + \int\limits^{a x + b y} G\left(\frac{a \xi + b \eta}{a^2 + b^2}, + \frac{- a \eta + b \xi}{a^2 + b^2} \right) + e^{\frac{c \xi}{a^2 + b^2}}\, d\xi\right] + e^{- \frac{c \xi}{a^2 + b^2}} + \right|_{\substack{\eta=- a y + b x\\ \xi=a x + b y }}\, , + + where `F(\eta)` is an arbitrary single-valued function. The solution + can be found in SymPy with ``pdsolve``:: + + >>> from sympy.solvers import pdsolve + >>> from sympy.abc import x, y, a, b, c + >>> from sympy import Function, pprint + >>> f = Function('f') + >>> G = Function('G') + >>> u = f(x, y) + >>> ux = u.diff(x) + >>> uy = u.diff(y) + >>> genform = a*ux + b*uy + c*u - G(x,y) + >>> pprint(genform) + d d + a*--(f(x, y)) + b*--(f(x, y)) + c*f(x, y) - G(x, y) + dx dy + >>> pprint(pdsolve(genform, hint='1st_linear_constant_coeff_Integral')) + // a*x + b*y \ \| + || / | || + || | | || + || | c*xi | || + || | ------- | || + || | 2 2 | || + || | /a*xi + b*eta -a*eta + b*xi\ a + b | || + || | G|------------, -------------|*e d(xi)| || + || | | 2 2 2 2 | | || + || | \ a + b a + b / | -c*xi || + || | | -------|| + || / | 2 2|| + || | a + b || + f(x, y) = ||F(eta) + -------------------------------------------------------|*e || + || 2 2 | || + \\ a + b / /|eta=-a*y + b*x, xi=a*x + b*y + + Examples + ======== + + >>> from sympy.solvers.pde import pdsolve + >>> from sympy import Function, pprint, exp + >>> from sympy.abc import x,y + >>> f = Function('f') + >>> eq = -2*f(x,y).diff(x) + 4*f(x,y).diff(y) + 5*f(x,y) - exp(x + 3*y) + >>> pdsolve(eq) + Eq(f(x, y), (F(4*x + 2*y)*exp(x/2) + exp(x + 4*y)/15)*exp(-y)) + + References + ========== + + - Viktor Grigoryan, "Partial Differential Equations" + Math 124A - Fall 2010, pp.7 + + """ + + # TODO : For now homogeneous first order linear PDE's having + # two variables are implemented. Once there is support for + # solving systems of ODE's, this can be extended to n variables. + xi, eta = symbols("xi eta") + f = func.func + x = func.args[0] + y = func.args[1] + b = match[match['b']] + c = match[match['c']] + d = match[match['d']] + e = -match[match['e']] + expterm = exp(-S(d)/(b**2 + c**2)*xi) + functerm = solvefun(eta) + solvedict = solve((b*x + c*y - xi, c*x - b*y - eta), x, y) + # Integral should remain as it is in terms of xi, + # doit() should be done in _handle_Integral. + genterm = (1/S(b**2 + c**2))*Integral( + (1/expterm*e).subs(solvedict), (xi, b*x + c*y)) + return Eq(f(x,y), Subs(expterm*(functerm + genterm), + (eta, xi), (c*x - b*y, b*x + c*y))) + + +def pde_1st_linear_variable_coeff(eq, func, order, match, solvefun): + r""" + Solves a first order linear partial differential equation + with variable coefficients. The general form of this partial + differential equation is + + .. math:: a(x, y) \frac{\partial f(x, y)}{\partial x} + + b(x, y) \frac{\partial f(x, y)}{\partial y} + + c(x, y) f(x, y) = G(x, y) + + where `a(x, y)`, `b(x, y)`, `c(x, y)` and `G(x, y)` are arbitrary + functions in `x` and `y`. This PDE is converted into an ODE by + making the following transformation: + + 1. `\xi` as `x` + + 2. `\eta` as the constant in the solution to the differential + equation `\frac{dy}{dx} = -\frac{b}{a}` + + Making the previous substitutions reduces it to the linear ODE + + .. math:: a(\xi, \eta)\frac{du}{d\xi} + c(\xi, \eta)u - G(\xi, \eta) = 0 + + which can be solved using ``dsolve``. + + >>> from sympy.abc import x, y + >>> from sympy import Function, pprint + >>> a, b, c, G, f= [Function(i) for i in ['a', 'b', 'c', 'G', 'f']] + >>> u = f(x,y) + >>> ux = u.diff(x) + >>> uy = u.diff(y) + >>> genform = a(x, y)*u + b(x, y)*ux + c(x, y)*uy - G(x,y) + >>> pprint(genform) + d d + -G(x, y) + a(x, y)*f(x, y) + b(x, y)*--(f(x, y)) + c(x, y)*--(f(x, y)) + dx dy + + + Examples + ======== + + >>> from sympy.solvers.pde import pdsolve + >>> from sympy import Function, pprint + >>> from sympy.abc import x,y + >>> f = Function('f') + >>> eq = x*(u.diff(x)) - y*(u.diff(y)) + y**2*u - y**2 + >>> pdsolve(eq) + Eq(f(x, y), F(x*y)*exp(y**2/2) + 1) + + References + ========== + + - Viktor Grigoryan, "Partial Differential Equations" + Math 124A - Fall 2010, pp.7 + + """ + from sympy.solvers.ode import dsolve + + eta = symbols("eta") + f = func.func + x = func.args[0] + y = func.args[1] + b = match[match['b']] + c = match[match['c']] + d = match[match['d']] + e = -match[match['e']] + + + if not d: + # To deal with cases like b*ux = e or c*uy = e + if not (b and c): + if c: + try: + tsol = integrate(e/c, y) + except NotImplementedError: + raise NotImplementedError("Unable to find a solution" + " due to inability of integrate") + else: + return Eq(f(x,y), solvefun(x) + tsol) + if b: + try: + tsol = integrate(e/b, x) + except NotImplementedError: + raise NotImplementedError("Unable to find a solution" + " due to inability of integrate") + else: + return Eq(f(x,y), solvefun(y) + tsol) + + if not c: + # To deal with cases when c is 0, a simpler method is used. + # The PDE reduces to b*(u.diff(x)) + d*u = e, which is a linear ODE in x + plode = f(x).diff(x)*b + d*f(x) - e + sol = dsolve(plode, f(x)) + syms = sol.free_symbols - plode.free_symbols - {x, y} + rhs = _simplify_variable_coeff(sol.rhs, syms, solvefun, y) + return Eq(f(x, y), rhs) + + if not b: + # To deal with cases when b is 0, a simpler method is used. + # The PDE reduces to c*(u.diff(y)) + d*u = e, which is a linear ODE in y + plode = f(y).diff(y)*c + d*f(y) - e + sol = dsolve(plode, f(y)) + syms = sol.free_symbols - plode.free_symbols - {x, y} + rhs = _simplify_variable_coeff(sol.rhs, syms, solvefun, x) + return Eq(f(x, y), rhs) + + dummy = Function('d') + h = (c/b).subs(y, dummy(x)) + sol = dsolve(dummy(x).diff(x) - h, dummy(x)) + if isinstance(sol, list): + sol = sol[0] + solsym = sol.free_symbols - h.free_symbols - {x, y} + if len(solsym) == 1: + solsym = solsym.pop() + etat = (solve(sol, solsym)[0]).subs(dummy(x), y) + ysub = solve(eta - etat, y)[0] + deq = (b*(f(x).diff(x)) + d*f(x) - e).subs(y, ysub) + final = (dsolve(deq, f(x), hint='1st_linear')).rhs + if isinstance(final, list): + final = final[0] + finsyms = final.free_symbols - deq.free_symbols - {x, y} + rhs = _simplify_variable_coeff(final, finsyms, solvefun, etat) + return Eq(f(x, y), rhs) + + else: + raise NotImplementedError("Cannot solve the partial differential equation due" + " to inability of constantsimp") + + +def _simplify_variable_coeff(sol, syms, func, funcarg): + r""" + Helper function to replace constants by functions in 1st_linear_variable_coeff + """ + eta = Symbol("eta") + if len(syms) == 1: + sym = syms.pop() + final = sol.subs(sym, func(funcarg)) + + else: + for sym in syms: + final = sol.subs(sym, func(funcarg)) + + return simplify(final.subs(eta, funcarg)) + + +def pde_separate(eq, fun, sep, strategy='mul'): + """Separate variables in partial differential equation either by additive + or multiplicative separation approach. It tries to rewrite an equation so + that one of the specified variables occurs on a different side of the + equation than the others. + + :param eq: Partial differential equation + + :param fun: Original function F(x, y, z) + + :param sep: List of separated functions [X(x), u(y, z)] + + :param strategy: Separation strategy. You can choose between additive + separation ('add') and multiplicative separation ('mul') which is + default. + + Examples + ======== + + >>> from sympy import E, Eq, Function, pde_separate, Derivative as D + >>> from sympy.abc import x, t + >>> u, X, T = map(Function, 'uXT') + + >>> eq = Eq(D(u(x, t), x), E**(u(x, t))*D(u(x, t), t)) + >>> pde_separate(eq, u(x, t), [X(x), T(t)], strategy='add') + [exp(-X(x))*Derivative(X(x), x), exp(T(t))*Derivative(T(t), t)] + + >>> eq = Eq(D(u(x, t), x, 2), D(u(x, t), t, 2)) + >>> pde_separate(eq, u(x, t), [X(x), T(t)], strategy='mul') + [Derivative(X(x), (x, 2))/X(x), Derivative(T(t), (t, 2))/T(t)] + + See Also + ======== + pde_separate_add, pde_separate_mul + """ + + do_add = False + if strategy == 'add': + do_add = True + elif strategy == 'mul': + do_add = False + else: + raise ValueError('Unknown strategy: %s' % strategy) + + if isinstance(eq, Equality): + if eq.rhs != 0: + return pde_separate(Eq(eq.lhs - eq.rhs, 0), fun, sep, strategy) + else: + return pde_separate(Eq(eq, 0), fun, sep, strategy) + + if eq.rhs != 0: + raise ValueError("Value should be 0") + + # Handle arguments + orig_args = list(fun.args) + subs_args = [arg for s in sep for arg in s.args] + + if do_add: + functions = reduce(operator.add, sep) + else: + functions = reduce(operator.mul, sep) + + # Check whether variables match + if len(subs_args) != len(orig_args): + raise ValueError("Variable counts do not match") + # Check for duplicate arguments like [X(x), u(x, y)] + if has_dups(subs_args): + raise ValueError("Duplicate substitution arguments detected") + # Check whether the variables match + if set(orig_args) != set(subs_args): + raise ValueError("Arguments do not match") + + # Substitute original function with separated... + result = eq.lhs.subs(fun, functions).doit() + + # Divide by terms when doing multiplicative separation + if not do_add: + eq = 0 + for i in result.args: + eq += i/functions + result = eq + + svar = subs_args[0] + dvar = subs_args[1:] + return _separate(result, svar, dvar) + + +def pde_separate_add(eq, fun, sep): + """ + Helper function for searching additive separable solutions. + + Consider an equation of two independent variables x, y and a dependent + variable w, we look for the product of two functions depending on different + arguments: + + `w(x, y, z) = X(x) + y(y, z)` + + Examples + ======== + + >>> from sympy import E, Eq, Function, pde_separate_add, Derivative as D + >>> from sympy.abc import x, t + >>> u, X, T = map(Function, 'uXT') + + >>> eq = Eq(D(u(x, t), x), E**(u(x, t))*D(u(x, t), t)) + >>> pde_separate_add(eq, u(x, t), [X(x), T(t)]) + [exp(-X(x))*Derivative(X(x), x), exp(T(t))*Derivative(T(t), t)] + + """ + return pde_separate(eq, fun, sep, strategy='add') + + +def pde_separate_mul(eq, fun, sep): + """ + Helper function for searching multiplicative separable solutions. + + Consider an equation of two independent variables x, y and a dependent + variable w, we look for the product of two functions depending on different + arguments: + + `w(x, y, z) = X(x)*u(y, z)` + + Examples + ======== + + >>> from sympy import Function, Eq, pde_separate_mul, Derivative as D + >>> from sympy.abc import x, y + >>> u, X, Y = map(Function, 'uXY') + + >>> eq = Eq(D(u(x, y), x, 2), D(u(x, y), y, 2)) + >>> pde_separate_mul(eq, u(x, y), [X(x), Y(y)]) + [Derivative(X(x), (x, 2))/X(x), Derivative(Y(y), (y, 2))/Y(y)] + + """ + return pde_separate(eq, fun, sep, strategy='mul') + + +def _separate(eq, dep, others): + """Separate expression into two parts based on dependencies of variables.""" + + # FIRST PASS + # Extract derivatives depending our separable variable... + terms = set() + for term in eq.args: + if term.is_Mul: + for i in term.args: + if i.is_Derivative and not i.has(*others): + terms.add(term) + continue + elif term.is_Derivative and not term.has(*others): + terms.add(term) + # Find the factor that we need to divide by + div = set() + for term in terms: + ext, sep = term.expand().as_independent(dep) + # Failed? + if sep.has(*others): + return None + div.add(ext) + # FIXME: Find lcm() of all the divisors and divide with it, instead of + # current hack :( + # https://github.com/sympy/sympy/issues/4597 + if len(div) > 0: + # double sum required or some tests will fail + eq = Add(*[simplify(Add(*[term/i for i in div])) for term in eq.args]) + # SECOND PASS - separate the derivatives + div = set() + lhs = rhs = 0 + for term in eq.args: + # Check, whether we have already term with independent variable... + if not term.has(*others): + lhs += term + continue + # ...otherwise, try to separate + temp, sep = term.expand().as_independent(dep) + # Failed? + if sep.has(*others): + return None + # Extract the divisors + div.add(sep) + rhs -= term.expand() + # Do the division + fulldiv = reduce(operator.add, div) + lhs = simplify(lhs/fulldiv).expand() + rhs = simplify(rhs/fulldiv).expand() + # ...and check whether we were successful :) + if lhs.has(*others) or rhs.has(dep): + return None + return [lhs, rhs] diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/polysys.py b/.venv/lib/python3.13/site-packages/sympy/solvers/polysys.py new file mode 100644 index 0000000000000000000000000000000000000000..2edc70b36c25b986c975c33fcc57535ef0b31df2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/polysys.py @@ -0,0 +1,872 @@ +"""Solvers of systems of polynomial equations. """ + +from __future__ import annotations + +from typing import Any +from collections.abc import Sequence, Iterable + +import itertools + +from sympy import Dummy +from sympy.core import S +from sympy.core.expr import Expr +from sympy.core.exprtools import factor_terms +from sympy.core.sorting import default_sort_key +from sympy.logic.boolalg import Boolean +from sympy.polys import Poly, groebner, roots +from sympy.polys.domains import ZZ +from sympy.polys.polyoptions import build_options +from sympy.polys.polytools import parallel_poly_from_expr, sqf_part +from sympy.polys.polyerrors import ( + ComputationFailed, + PolificationFailed, + CoercionFailed, + GeneratorsNeeded, + DomainError +) +from sympy.simplify import rcollect +from sympy.utilities import postfixes +from sympy.utilities.iterables import cartes +from sympy.utilities.misc import filldedent +from sympy.logic.boolalg import Or, And +from sympy.core.relational import Eq + + +class SolveFailed(Exception): + """Raised when solver's conditions were not met. """ + + +def solve_poly_system(seq, *gens, strict=False, **args): + """ + Return a list of solutions for the system of polynomial equations + or else None. + + Parameters + ========== + + seq: a list/tuple/set + Listing all the equations that are needed to be solved + gens: generators + generators of the equations in seq for which we want the + solutions + strict: a boolean (default is False) + if strict is True, NotImplementedError will be raised if + the solution is known to be incomplete (which can occur if + not all solutions are expressible in radicals) + args: Keyword arguments + Special options for solving the equations. + + + Returns + ======= + + List[Tuple] + a list of tuples with elements being solutions for the + symbols in the order they were passed as gens + None + None is returned when the computed basis contains only the ground. + + Examples + ======== + + >>> from sympy import solve_poly_system + >>> from sympy.abc import x, y + + >>> solve_poly_system([x*y - 2*y, 2*y**2 - x**2], x, y) + [(0, 0), (2, -sqrt(2)), (2, sqrt(2))] + + >>> solve_poly_system([x**5 - x + y**3, y**2 - 1], x, y, strict=True) + Traceback (most recent call last): + ... + UnsolvableFactorError + + """ + try: + polys, opt = parallel_poly_from_expr(seq, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('solve_poly_system', len(seq), exc) + + if len(polys) == len(opt.gens) == 2: + f, g = polys + + if all(i <= 2 for i in f.degree_list() + g.degree_list()): + try: + return solve_biquadratic(f, g, opt) + except SolveFailed: + pass + + return solve_generic(polys, opt, strict=strict) + + +def solve_biquadratic(f, g, opt): + """Solve a system of two bivariate quadratic polynomial equations. + + Parameters + ========== + + f: a single Expr or Poly + First equation + g: a single Expr or Poly + Second Equation + opt: an Options object + For specifying keyword arguments and generators + + Returns + ======= + + List[Tuple] + a list of tuples with elements being solutions for the + symbols in the order they were passed as gens + None + None is returned when the computed basis contains only the ground. + + Examples + ======== + + >>> from sympy import Options, Poly + >>> from sympy.abc import x, y + >>> from sympy.solvers.polysys import solve_biquadratic + >>> NewOption = Options((x, y), {'domain': 'ZZ'}) + + >>> a = Poly(y**2 - 4 + x, y, x, domain='ZZ') + >>> b = Poly(y*2 + 3*x - 7, y, x, domain='ZZ') + >>> solve_biquadratic(a, b, NewOption) + [(1/3, 3), (41/27, 11/9)] + + >>> a = Poly(y + x**2 - 3, y, x, domain='ZZ') + >>> b = Poly(-y + x - 4, y, x, domain='ZZ') + >>> solve_biquadratic(a, b, NewOption) + [(7/2 - sqrt(29)/2, -sqrt(29)/2 - 1/2), (sqrt(29)/2 + 7/2, -1/2 + \ + sqrt(29)/2)] + """ + G = groebner([f, g]) + + if len(G) == 1 and G[0].is_ground: + return None + + if len(G) != 2: + raise SolveFailed + + x, y = opt.gens + p, q = G + if not p.gcd(q).is_ground: + # not 0-dimensional + raise SolveFailed + + p = Poly(p, x, expand=False) + p_roots = [rcollect(expr, y) for expr in roots(p).keys()] + + q = q.ltrim(-1) + q_roots = list(roots(q).keys()) + + solutions = [(p_root.subs(y, q_root), q_root) for q_root, p_root in + itertools.product(q_roots, p_roots)] + + return sorted(solutions, key=default_sort_key) + + +def solve_generic(polys, opt, strict=False): + """ + Solve a generic system of polynomial equations. + + Returns all possible solutions over C[x_1, x_2, ..., x_m] of a + set F = { f_1, f_2, ..., f_n } of polynomial equations, using + Groebner basis approach. For now only zero-dimensional systems + are supported, which means F can have at most a finite number + of solutions. If the basis contains only the ground, None is + returned. + + The algorithm works by the fact that, supposing G is the basis + of F with respect to an elimination order (here lexicographic + order is used), G and F generate the same ideal, they have the + same set of solutions. By the elimination property, if G is a + reduced, zero-dimensional Groebner basis, then there exists an + univariate polynomial in G (in its last variable). This can be + solved by computing its roots. Substituting all computed roots + for the last (eliminated) variable in other elements of G, new + polynomial system is generated. Applying the above procedure + recursively, a finite number of solutions can be found. + + The ability of finding all solutions by this procedure depends + on the root finding algorithms. If no solutions were found, it + means only that roots() failed, but the system is solvable. To + overcome this difficulty use numerical algorithms instead. + + Parameters + ========== + + polys: a list/tuple/set + Listing all the polynomial equations that are needed to be solved + opt: an Options object + For specifying keyword arguments and generators + strict: a boolean + If strict is True, NotImplementedError will be raised if the solution + is known to be incomplete + + Returns + ======= + + List[Tuple] + a list of tuples with elements being solutions for the + symbols in the order they were passed as gens + None + None is returned when the computed basis contains only the ground. + + References + ========== + + .. [Buchberger01] B. Buchberger, Groebner Bases: A Short + Introduction for Systems Theorists, In: R. Moreno-Diaz, + B. Buchberger, J.L. Freire, Proceedings of EUROCAST'01, + February, 2001 + + .. [Cox97] D. Cox, J. Little, D. O'Shea, Ideals, Varieties + and Algorithms, Springer, Second Edition, 1997, pp. 112 + + Raises + ======== + + NotImplementedError + If the system is not zero-dimensional (does not have a finite + number of solutions) + + UnsolvableFactorError + If ``strict`` is True and not all solution components are + expressible in radicals + + Examples + ======== + + >>> from sympy import Poly, Options + >>> from sympy.solvers.polysys import solve_generic + >>> from sympy.abc import x, y + >>> NewOption = Options((x, y), {'domain': 'ZZ'}) + + >>> a = Poly(x - y + 5, x, y, domain='ZZ') + >>> b = Poly(x + y - 3, x, y, domain='ZZ') + >>> solve_generic([a, b], NewOption) + [(-1, 4)] + + >>> a = Poly(x - 2*y + 5, x, y, domain='ZZ') + >>> b = Poly(2*x - y - 3, x, y, domain='ZZ') + >>> solve_generic([a, b], NewOption) + [(11/3, 13/3)] + + >>> a = Poly(x**2 + y, x, y, domain='ZZ') + >>> b = Poly(x + y*4, x, y, domain='ZZ') + >>> solve_generic([a, b], NewOption) + [(0, 0), (1/4, -1/16)] + + >>> a = Poly(x**5 - x + y**3, x, y, domain='ZZ') + >>> b = Poly(y**2 - 1, x, y, domain='ZZ') + >>> solve_generic([a, b], NewOption, strict=True) + Traceback (most recent call last): + ... + UnsolvableFactorError + + """ + def _is_univariate(f): + """Returns True if 'f' is univariate in its last variable. """ + for monom in f.monoms(): + if any(monom[:-1]): + return False + + return True + + def _subs_root(f, gen, zero): + """Replace generator with a root so that the result is nice. """ + p = f.as_expr({gen: zero}) + + if f.degree(gen) >= 2: + p = p.expand(deep=False) + + return p + + def _solve_reduced_system(system, gens, entry=False): + """Recursively solves reduced polynomial systems. """ + if len(system) == len(gens) == 1: + # the below line will produce UnsolvableFactorError if + # strict=True and the solution from `roots` is incomplete + zeros = list(roots(system[0], gens[-1], strict=strict).keys()) + return [(zero,) for zero in zeros] + + basis = groebner(system, gens, polys=True) + + if len(basis) == 1 and basis[0].is_ground: + if not entry: + return [] + else: + return None + + univariate = list(filter(_is_univariate, basis)) + + if len(basis) < len(gens): + raise NotImplementedError(filldedent(''' + only zero-dimensional systems supported + (finite number of solutions) + ''')) + + if len(univariate) == 1: + f = univariate.pop() + else: + raise NotImplementedError(filldedent(''' + only zero-dimensional systems supported + (finite number of solutions) + ''')) + + gens = f.gens + gen = gens[-1] + + # the below line will produce UnsolvableFactorError if + # strict=True and the solution from `roots` is incomplete + zeros = list(roots(f.ltrim(gen), strict=strict).keys()) + + if not zeros: + return [] + + if len(basis) == 1: + return [(zero,) for zero in zeros] + + solutions = [] + + for zero in zeros: + new_system = [] + new_gens = gens[:-1] + + for b in basis[:-1]: + eq = _subs_root(b, gen, zero) + + if eq is not S.Zero: + new_system.append(eq) + + for solution in _solve_reduced_system(new_system, new_gens): + solutions.append(solution + (zero,)) + + if solutions and len(solutions[0]) != len(gens): + raise NotImplementedError(filldedent(''' + only zero-dimensional systems supported + (finite number of solutions) + ''')) + return solutions + + try: + result = _solve_reduced_system(polys, opt.gens, entry=True) + except CoercionFailed: + raise NotImplementedError + + if result is not None: + return sorted(result, key=default_sort_key) + + +def solve_triangulated(polys, *gens, **args): + """ + Solve a polynomial system using Gianni-Kalkbrenner algorithm. + + The algorithm proceeds by computing one Groebner basis in the ground + domain and then by iteratively computing polynomial factorizations in + appropriately constructed algebraic extensions of the ground domain. + + Parameters + ========== + + polys: a list/tuple/set + Listing all the equations that are needed to be solved + gens: generators + generators of the equations in polys for which we want the + solutions + args: Keyword arguments + Special options for solving the equations + + Returns + ======= + + List[Tuple] + A List of tuples. Solutions for symbols that satisfy the + equations listed in polys + + Examples + ======== + + >>> from sympy import solve_triangulated + >>> from sympy.abc import x, y, z + + >>> F = [x**2 + y + z - 1, x + y**2 + z - 1, x + y + z**2 - 1] + + >>> solve_triangulated(F, x, y, z) + [(0, 0, 1), (0, 1, 0), (1, 0, 0)] + + Using extension for algebraic solutions. + + >>> solve_triangulated(F, x, y, z, extension=True) #doctest: +NORMALIZE_WHITESPACE + [(0, 0, 1), (0, 1, 0), (1, 0, 0), + (CRootOf(x**2 + 2*x - 1, 0), CRootOf(x**2 + 2*x - 1, 0), CRootOf(x**2 + 2*x - 1, 0)), + (CRootOf(x**2 + 2*x - 1, 1), CRootOf(x**2 + 2*x - 1, 1), CRootOf(x**2 + 2*x - 1, 1))] + + References + ========== + + 1. Patrizia Gianni, Teo Mora, Algebraic Solution of System of + Polynomial Equations using Groebner Bases, AAECC-5 on Applied Algebra, + Algebraic Algorithms and Error-Correcting Codes, LNCS 356 247--257, 1989 + + """ + opt = build_options(gens, args) + + G = groebner(polys, gens, polys=True) + G = list(reversed(G)) + + extension = opt.get('extension', False) + if extension: + def _solve_univariate(f): + return [r for r, _ in f.all_roots(multiple=False, radicals=False)] + else: + domain = opt.get('domain') + + if domain is not None: + for i, g in enumerate(G): + G[i] = g.set_domain(domain) + + def _solve_univariate(f): + return list(f.ground_roots().keys()) + + f, G = G[0].ltrim(-1), G[1:] + dom = f.get_domain() + + zeros = _solve_univariate(f) + + if extension: + solutions = {((zero,), dom.algebraic_field(zero)) for zero in zeros} + else: + solutions = {((zero,), dom) for zero in zeros} + + var_seq = reversed(gens[:-1]) + vars_seq = postfixes(gens[1:]) + + for var, vars in zip(var_seq, vars_seq): + _solutions = set() + + for values, dom in solutions: + H, mapping = [], list(zip(vars, values)) + + for g in G: + _vars = (var,) + vars + + if g.has_only_gens(*_vars) and g.degree(var) != 0: + if extension: + g = g.set_domain(g.domain.unify(dom)) + h = g.ltrim(var).eval(dict(mapping)) + + if g.degree(var) == h.degree(): + H.append(h) + + p = min(H, key=lambda h: h.degree()) + zeros = _solve_univariate(p) + + for zero in zeros: + if not (zero in dom): + dom_zero = dom.algebraic_field(zero) + else: + dom_zero = dom + + _solutions.add(((zero,) + values, dom_zero)) + + solutions = _solutions + return sorted((s for s, _ in solutions), key=default_sort_key) + + +def factor_system(eqs: Sequence[Expr | complex], gens: Sequence[Expr] = (), **kwargs: Any) -> list[list[Expr]]: + """ + Factorizes a system of polynomial equations into + irreducible subsystems. + + Parameters + ========== + + eqs : list + List of expressions to be factored. + Each expression is assumed to be equal to zero. + + gens : list, optional + Generator(s) of the polynomial ring. + If not provided, all free symbols will be used. + + **kwargs : dict, optional + Same optional arguments taken by ``factor`` + + Returns + ======= + + list[list[Expr]] + A list of lists of expressions, where each sublist represents + an irreducible subsystem. When solved, each subsystem gives + one component of the solution. Only generic solutions are + returned (cases not requiring parameters to be zero). + + Examples + ======== + + >>> from sympy.solvers.polysys import factor_system, factor_system_cond + >>> from sympy.abc import x, y, a, b, c + + A simple system with multiple solutions: + + >>> factor_system([x**2 - 1, y - 1]) + [[x + 1, y - 1], [x - 1, y - 1]] + + A system with no solution: + + >>> factor_system([x, 1]) + [] + + A system where any value of the symbol(s) is a solution: + + >>> factor_system([x - x, (x + 1)**2 - (x**2 + 2*x + 1)]) + [[]] + + A system with no generic solution: + + >>> factor_system([a*x*(x-1), b*y, c], [x, y]) + [] + + If c is added to the unknowns then the system has a generic solution: + + >>> factor_system([a*x*(x-1), b*y, c], [x, y, c]) + [[x - 1, y, c], [x, y, c]] + + Alternatively :func:`factor_system_cond` can be used to get degenerate + cases as well: + + >>> factor_system_cond([a*x*(x-1), b*y, c], [x, y]) + [[x - 1, y, c], [x, y, c], [x - 1, b, c], [x, b, c], [y, a, c], [a, b, c]] + + Each of the above cases is only satisfiable in the degenerate case `c = 0`. + + The solution set of the original system represented + by eqs is the union of the solution sets of the + factorized systems. + + An empty list [] means no generic solution exists. + A list containing an empty list [[]] means any value of + the symbol(s) is a solution. + + See Also + ======== + + factor_system_cond : Returns both generic and degenerate solutions + factor_system_bool : Returns a Boolean combination representing all solutions + sympy.polys.polytools.factor : Factors a polynomial into irreducible factors + over the rational numbers + """ + + systems = _factor_system_poly_from_expr(eqs, gens, **kwargs) + systems_generic = [sys for sys in systems if not _is_degenerate(sys)] + systems_expr = [[p.as_expr() for p in system] for system in systems_generic] + return systems_expr + + +def _is_degenerate(system: list[Poly]) -> bool: + """Helper function to check if a system is degenerate""" + return any(p.is_ground for p in system) + + +def factor_system_bool(eqs: Sequence[Expr | complex], gens: Sequence[Expr] = (), **kwargs: Any) -> Boolean: + """ + Factorizes a system of polynomial equations into irreducible DNF. + + The system of expressions(eqs) is taken and a Boolean combination + of equations is returned that represents the same solution set. + The result is in disjunctive normal form (OR of ANDs). + + Parameters + ========== + + eqs : list + List of expressions to be factored. + Each expression is assumed to be equal to zero. + + gens : list, optional + Generator(s) of the polynomial ring. + If not provided, all free symbols will be used. + + **kwargs : dict, optional + Optional keyword arguments + + + Returns + ======= + + Boolean: + A Boolean combination of equations. The result is typically in + the form of a conjunction (AND) of a disjunctive normal form + with additional conditions. + + Examples + ======== + + >>> from sympy.solvers.polysys import factor_system_bool + >>> from sympy.abc import x, y, a, b, c + >>> factor_system_bool([x**2 - 1]) + Eq(x - 1, 0) | Eq(x + 1, 0) + + >>> factor_system_bool([x**2 - 1, y - 1]) + (Eq(x - 1, 0) & Eq(y - 1, 0)) | (Eq(x + 1, 0) & Eq(y - 1, 0)) + + >>> eqs = [a * (x - 1), b] + >>> factor_system_bool([a*(x - 1), b]) + (Eq(a, 0) & Eq(b, 0)) | (Eq(b, 0) & Eq(x - 1, 0)) + + >>> factor_system_bool([a*x**2 - a, b*(x + 1), c], [x]) + (Eq(c, 0) & Eq(x + 1, 0)) | (Eq(a, 0) & Eq(b, 0) & Eq(c, 0)) | (Eq(b, 0) & Eq(c, 0) & Eq(x - 1, 0)) + + >>> factor_system_bool([x**2 + 2*x + 1 - (x + 1)**2]) + True + + The result is logically equivalent to the system of equations + i.e. eqs. The function returns ``True`` when all values of + the symbol(s) is a solution and ``False`` when the system + cannot be solved. + + See Also + ======== + + factor_system : Returns factors and solvability condition separately + factor_system_cond : Returns both factors and conditions + + """ + + systems = factor_system_cond(eqs, gens, **kwargs) + return Or(*[And(*[Eq(eq, 0) for eq in sys]) for sys in systems]) + + +def factor_system_cond(eqs: Sequence[Expr | complex], gens: Sequence[Expr] = (), **kwargs: Any) -> list[list[Expr]]: + """ + Factorizes a polynomial system into irreducible components and returns + both generic and degenerate solutions. + + Parameters + ========== + + eqs : list + List of expressions to be factored. + Each expression is assumed to be equal to zero. + + gens : list, optional + Generator(s) of the polynomial ring. + If not provided, all free symbols will be used. + + **kwargs : dict, optional + Optional keyword arguments. + + Returns + ======= + + list[list[Expr]] + A list of lists of expressions, where each sublist represents + an irreducible subsystem. Includes both generic solutions and + degenerate cases requiring equality conditions on parameters. + + Examples + ======== + + >>> from sympy.solvers.polysys import factor_system_cond + >>> from sympy.abc import x, y, a, b, c + + >>> factor_system_cond([x**2 - 4, a*y, b], [x, y]) + [[x + 2, y, b], [x - 2, y, b], [x + 2, a, b], [x - 2, a, b]] + + >>> factor_system_cond([a*x*(x-1), b*y, c], [x, y]) + [[x - 1, y, c], [x, y, c], [x - 1, b, c], [x, b, c], [y, a, c], [a, b, c]] + + An empty list [] means no solution exists. + A list containing an empty list [[]] means any value of + the symbol(s) is a solution. + + See Also + ======== + + factor_system : Returns only generic solutions + factor_system_bool : Returns a Boolean combination representing all solutions + sympy.polys.polytools.factor : Factors a polynomial into irreducible factors + over the rational numbers + """ + systems_poly = _factor_system_poly_from_expr(eqs, gens, **kwargs) + systems = [[p.as_expr() for p in system] for system in systems_poly] + return systems + + +def _factor_system_poly_from_expr( + eqs: Sequence[Expr | complex], gens: Sequence[Expr], **kwargs: Any +) -> list[list[Poly]]: + """ + Convert expressions to polynomials and factor the system. + + Takes a sequence of expressions, converts them to + polynomials, and factors the resulting system. Handles both regular + polynomial systems and purely numerical cases. + """ + try: + polys, opts = parallel_poly_from_expr(eqs, *gens, **kwargs) + only_numbers = False + except (GeneratorsNeeded, PolificationFailed): + _u = Dummy('u') + polys, opts = parallel_poly_from_expr(eqs, [_u], **kwargs) + assert opts['domain'].is_Numerical + only_numbers = True + + if only_numbers: + return [[]] if all(p == 0 for p in polys) else [] + + return factor_system_poly(polys) + + +def factor_system_poly(polys: list[Poly]) -> list[list[Poly]]: + """ + Factors a system of polynomial equations into irreducible subsystems + + Core implementation that works directly with Poly instances. + + Parameters + ========== + + polys : list[Poly] + A list of Poly instances to be factored. + + Returns + ======= + + list[list[Poly]] + A list of lists of polynomials, where each sublist represents + an irreducible component of the solution. Includes both + generic and degenerate cases. + + Examples + ======== + + >>> from sympy import symbols, Poly, ZZ + >>> from sympy.solvers.polysys import factor_system_poly + >>> a, b, c, x = symbols('a b c x') + >>> p1 = Poly((a - 1)*(x - 2), x, domain=ZZ[a,b,c]) + >>> p2 = Poly((b - 3)*(x - 2), x, domain=ZZ[a,b,c]) + >>> p3 = Poly(c, x, domain=ZZ[a,b,c]) + + The equation to be solved for x is ``x - 2 = 0`` provided either + of the two conditions on the parameters ``a`` and ``b`` is nonzero + and the constant parameter ``c`` should be zero. + + >>> sys1, sys2 = factor_system_poly([p1, p2, p3]) + >>> sys1 + [Poly(x - 2, x, domain='ZZ[a,b,c]'), + Poly(c, x, domain='ZZ[a,b,c]')] + >>> sys2 + [Poly(a - 1, x, domain='ZZ[a,b,c]'), + Poly(b - 3, x, domain='ZZ[a,b,c]'), + Poly(c, x, domain='ZZ[a,b,c]')] + + An empty list [] when returned means no solution exists. + Whereas a list containing an empty list [[]] means any value is a solution. + + See Also + ======== + + factor_system : Returns only generic solutions + factor_system_bool : Returns a Boolean combination representing the solutions + factor_system_cond : Returns both generic and degenerate solutions + sympy.polys.polytools.factor : Factors a polynomial into irreducible factors + over the rational numbers + """ + if not all(isinstance(poly, Poly) for poly in polys): + raise TypeError("polys should be a list of Poly instances") + if not polys: + return [[]] + + base_domain = polys[0].domain + base_gens = polys[0].gens + if not all(poly.domain == base_domain and poly.gens == base_gens for poly in polys[1:]): + raise DomainError("All polynomials must have the same domain and generators") + + factor_sets = [] + for poly in polys: + constant, factors_mult = poly.factor_list() + + if constant.is_zero is True: + continue + elif constant.is_zero is False: + if not factors_mult: + return [] + factor_sets.append([f for f, _ in factors_mult]) + else: + constant = sqf_part(factor_terms(constant).as_coeff_Mul()[1]) + constp = Poly(constant, base_gens, domain=base_domain) + factors = [f for f, _ in factors_mult] + factors.append(constp) + factor_sets.append(factors) + + if not factor_sets: + return [[]] + + result = _factor_sets(factor_sets) + return _sort_systems(result) + + +def _factor_sets_slow(eqs: list[list]) -> set[frozenset]: + """ + Helper to find the minimal set of factorised subsystems that is + equivalent to the original system. + + The result is in DNF. + """ + if not eqs: + return {frozenset()} + systems_set = {frozenset(sys) for sys in cartes(*eqs)} + return {s1 for s1 in systems_set if not any(s1 > s2 for s2 in systems_set)} + + +def _factor_sets(eqs: list[list]) -> set[frozenset]: + """ + Helper that builds factor combinations. + """ + if not eqs: + return {frozenset()} + + current_set = min(eqs, key=len) + other_sets = [s for s in eqs if s is not current_set] + + stack = [(factor, [s for s in other_sets if factor not in s], {factor}) + for factor in current_set] + + result = set() + + while stack: + factor, remaining_sets, current_solution = stack.pop() + + if not remaining_sets: + result.add(frozenset(current_solution)) + continue + + next_set = min(remaining_sets, key=len) + next_remaining = [s for s in remaining_sets if s is not next_set] + + for next_factor in next_set: + valid_remaining = [s for s in next_remaining if next_factor not in s] + new_solution = current_solution | {next_factor} + stack.append((next_factor, valid_remaining, new_solution)) + + return {s1 for s1 in result if not any(s1 > s2 for s2 in result)} + + +def _sort_systems(systems: Iterable[Iterable[Poly]]) -> list[list[Poly]]: + """Sorts a list of lists of polynomials""" + systems_list = [sorted(s, key=_poly_sort_key, reverse=True) for s in systems] + return sorted(systems_list, key=_sys_sort_key, reverse=True) + + +def _poly_sort_key(poly): + """Sort key for polynomials""" + if poly.domain.is_FF: + poly = poly.set_domain(ZZ) + return poly.degree_list(), poly.rep.to_list() + + +def _sys_sort_key(sys): + """Sort key for lists of polynomials""" + return list(zip(*map(_poly_sort_key, sys))) diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/recurr.py b/.venv/lib/python3.13/site-packages/sympy/solvers/recurr.py new file mode 100644 index 0000000000000000000000000000000000000000..ba627bbd4cb0844f11a8743634f5f10328aadca8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/recurr.py @@ -0,0 +1,843 @@ +r""" +This module is intended for solving recurrences or, in other words, +difference equations. Currently supported are linear, inhomogeneous +equations with polynomial or rational coefficients. + +The solutions are obtained among polynomials, rational functions, +hypergeometric terms, or combinations of hypergeometric term which +are pairwise dissimilar. + +``rsolve_X`` functions were meant as a low level interface +for ``rsolve`` which would use Mathematica's syntax. + +Given a recurrence relation: + + .. math:: a_{k}(n) y(n+k) + a_{k-1}(n) y(n+k-1) + + ... + a_{0}(n) y(n) = f(n) + +where `k > 0` and `a_{i}(n)` are polynomials in `n`. To use +``rsolve_X`` we need to put all coefficients in to a list ``L`` of +`k+1` elements the following way: + + ``L = [a_{0}(n), ..., a_{k-1}(n), a_{k}(n)]`` + +where ``L[i]``, for `i=0, \ldots, k`, maps to +`a_{i}(n) y(n+i)` (`y(n+i)` is implicit). + +For example if we would like to compute `m`-th Bernoulli polynomial +up to a constant (example was taken from rsolve_poly docstring), +then we would use `b(n+1) - b(n) = m n^{m-1}` recurrence, which +has solution `b(n) = B_m + C`. + +Then ``L = [-1, 1]`` and `f(n) = m n^(m-1)` and finally for `m=4`: + +>>> from sympy import Symbol, bernoulli, rsolve_poly +>>> n = Symbol('n', integer=True) + +>>> rsolve_poly([-1, 1], 4*n**3, n) +C0 + n**4 - 2*n**3 + n**2 + +>>> bernoulli(4, n) +n**4 - 2*n**3 + n**2 - 1/30 + +For the sake of completeness, `f(n)` can be: + + [1] a polynomial -> rsolve_poly + [2] a rational function -> rsolve_ratio + [3] a hypergeometric function -> rsolve_hyper +""" +from collections import defaultdict + +from sympy.concrete import product +from sympy.core.singleton import S +from sympy.core.numbers import Rational, I +from sympy.core.symbol import Symbol, Wild, Dummy +from sympy.core.relational import Equality +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.sorting import default_sort_key +from sympy.core.sympify import sympify + +from sympy.simplify import simplify, hypersimp, hypersimilar # type: ignore +from sympy.solvers import solve, solve_undetermined_coeffs +from sympy.polys import Poly, quo, gcd, lcm, roots, resultant +from sympy.functions import binomial, factorial, FallingFactorial, RisingFactorial +from sympy.matrices import Matrix, casoratian +from sympy.utilities.iterables import numbered_symbols + + +def rsolve_poly(coeffs, f, n, shift=0, **hints): + r""" + Given linear recurrence operator `\operatorname{L}` of order + `k` with polynomial coefficients and inhomogeneous equation + `\operatorname{L} y = f`, where `f` is a polynomial, we seek for + all polynomial solutions over field `K` of characteristic zero. + + The algorithm performs two basic steps: + + (1) Compute degree `N` of the general polynomial solution. + (2) Find all polynomials of degree `N` or less + of `\operatorname{L} y = f`. + + There are two methods for computing the polynomial solutions. + If the degree bound is relatively small, i.e. it's smaller than + or equal to the order of the recurrence, then naive method of + undetermined coefficients is being used. This gives a system + of algebraic equations with `N+1` unknowns. + + In the other case, the algorithm performs transformation of the + initial equation to an equivalent one for which the system of + algebraic equations has only `r` indeterminates. This method is + quite sophisticated (in comparison with the naive one) and was + invented together by Abramov, Bronstein and Petkovsek. + + It is possible to generalize the algorithm implemented here to + the case of linear q-difference and differential equations. + + Lets say that we would like to compute `m`-th Bernoulli polynomial + up to a constant. For this we can use `b(n+1) - b(n) = m n^{m-1}` + recurrence, which has solution `b(n) = B_m + C`. For example: + + >>> from sympy import Symbol, rsolve_poly + >>> n = Symbol('n', integer=True) + + >>> rsolve_poly([-1, 1], 4*n**3, n) + C0 + n**4 - 2*n**3 + n**2 + + References + ========== + + .. [1] S. A. Abramov, M. Bronstein and M. Petkovsek, On polynomial + solutions of linear operator equations, in: T. Levelt, ed., + Proc. ISSAC '95, ACM Press, New York, 1995, 290-296. + + .. [2] M. Petkovsek, Hypergeometric solutions of linear recurrences + with polynomial coefficients, J. Symbolic Computation, + 14 (1992), 243-264. + + .. [3] M. Petkovsek, H. S. Wilf, D. Zeilberger, A = B, 1996. + + """ + f = sympify(f) + + if not f.is_polynomial(n): + return None + + homogeneous = f.is_zero + + r = len(coeffs) - 1 + + coeffs = [Poly(coeff, n) for coeff in coeffs] + + polys = [Poly(0, n)]*(r + 1) + terms = [(S.Zero, S.NegativeInfinity)]*(r + 1) + + for i in range(r + 1): + for j in range(i, r + 1): + polys[i] += coeffs[j]*(binomial(j, i).as_poly(n)) + + if not polys[i].is_zero: + (exp,), coeff = polys[i].LT() + terms[i] = (coeff, exp) + + d = b = terms[0][1] + + for i in range(1, r + 1): + if terms[i][1] > d: + d = terms[i][1] + + if terms[i][1] - i > b: + b = terms[i][1] - i + + d, b = int(d), int(b) + + x = Dummy('x') + + degree_poly = S.Zero + + for i in range(r + 1): + if terms[i][1] - i == b: + degree_poly += terms[i][0]*FallingFactorial(x, i) + + nni_roots = list(roots(degree_poly, x, filter='Z', + predicate=lambda r: r >= 0).keys()) + + if nni_roots: + N = [max(nni_roots)] + else: + N = [] + + if homogeneous: + N += [-b - 1] + else: + N += [f.as_poly(n).degree() - b, -b - 1] + + N = int(max(N)) + + if N < 0: + if homogeneous: + if hints.get('symbols', False): + return (S.Zero, []) + else: + return S.Zero + else: + return None + + if N <= r: + C = [] + y = E = S.Zero + + for i in range(N + 1): + C.append(Symbol('C' + str(i + shift))) + y += C[i] * n**i + + for i in range(r + 1): + E += coeffs[i].as_expr()*y.subs(n, n + i) + + solutions = solve_undetermined_coeffs(E - f, C, n) + + if solutions is not None: + _C = C + C = [c for c in C if (c not in solutions)] + result = y.subs(solutions) + else: + return None # TBD + else: + A = r + U = N + A + b + 1 + + nni_roots = list(roots(polys[r], filter='Z', + predicate=lambda r: r >= 0).keys()) + + if nni_roots != []: + a = max(nni_roots) + 1 + else: + a = S.Zero + + def _zero_vector(k): + return [S.Zero] * k + + def _one_vector(k): + return [S.One] * k + + def _delta(p, k): + B = S.One + D = p.subs(n, a + k) + + for i in range(1, k + 1): + B *= Rational(i - k - 1, i) + D += B * p.subs(n, a + k - i) + + return D + + alpha = {} + + for i in range(-A, d + 1): + I = _one_vector(d + 1) + + for k in range(1, d + 1): + I[k] = I[k - 1] * (x + i - k + 1)/k + + alpha[i] = S.Zero + + for j in range(A + 1): + for k in range(d + 1): + B = binomial(k, i + j) + D = _delta(polys[j].as_expr(), k) + + alpha[i] += I[k]*B*D + + V = Matrix(U, A, lambda i, j: int(i == j)) + + if homogeneous: + for i in range(A, U): + v = _zero_vector(A) + + for k in range(1, A + b + 1): + if i - k < 0: + break + + B = alpha[k - A].subs(x, i - k) + + for j in range(A): + v[j] += B * V[i - k, j] + + denom = alpha[-A].subs(x, i) + + for j in range(A): + V[i, j] = -v[j] / denom + else: + G = _zero_vector(U) + + for i in range(A, U): + v = _zero_vector(A) + g = S.Zero + + for k in range(1, A + b + 1): + if i - k < 0: + break + + B = alpha[k - A].subs(x, i - k) + + for j in range(A): + v[j] += B * V[i - k, j] + + g += B * G[i - k] + + denom = alpha[-A].subs(x, i) + + for j in range(A): + V[i, j] = -v[j] / denom + + G[i] = (_delta(f, i - A) - g) / denom + + P, Q = _one_vector(U), _zero_vector(A) + + for i in range(1, U): + P[i] = (P[i - 1] * (n - a - i + 1)/i).expand() + + for i in range(A): + Q[i] = Add(*[(v*p).expand() for v, p in zip(V[:, i], P)]) + + if not homogeneous: + h = Add(*[(g*p).expand() for g, p in zip(G, P)]) + + C = [Symbol('C' + str(i + shift)) for i in range(A)] + + g = lambda i: Add(*[c*_delta(q, i) for c, q in zip(C, Q)]) + + if homogeneous: + E = [g(i) for i in range(N + 1, U)] + else: + E = [g(i) + _delta(h, i) for i in range(N + 1, U)] + + if E != []: + solutions = solve(E, *C) + + if not solutions: + if homogeneous: + if hints.get('symbols', False): + return (S.Zero, []) + else: + return S.Zero + else: + return None + else: + solutions = {} + + if homogeneous: + result = S.Zero + else: + result = h + + _C = C[:] + for c, q in list(zip(C, Q)): + if c in solutions: + s = solutions[c]*q + C.remove(c) + else: + s = c*q + + result += s.expand() + + if C != _C: + # renumber so they are contiguous + result = result.xreplace(dict(zip(C, _C))) + C = _C[:len(C)] + + if hints.get('symbols', False): + return (result, C) + else: + return result + + +def rsolve_ratio(coeffs, f, n, **hints): + r""" + Given linear recurrence operator `\operatorname{L}` of order `k` + with polynomial coefficients and inhomogeneous equation + `\operatorname{L} y = f`, where `f` is a polynomial, we seek + for all rational solutions over field `K` of characteristic zero. + + This procedure accepts only polynomials, however if you are + interested in solving recurrence with rational coefficients + then use ``rsolve`` which will pre-process the given equation + and run this procedure with polynomial arguments. + + The algorithm performs two basic steps: + + (1) Compute polynomial `v(n)` which can be used as universal + denominator of any rational solution of equation + `\operatorname{L} y = f`. + + (2) Construct new linear difference equation by substitution + `y(n) = u(n)/v(n)` and solve it for `u(n)` finding all its + polynomial solutions. Return ``None`` if none were found. + + The algorithm implemented here is a revised version of the original + Abramov's algorithm, developed in 1989. The new approach is much + simpler to implement and has better overall efficiency. This + method can be easily adapted to the q-difference equations case. + + Besides finding rational solutions alone, this functions is + an important part of Hyper algorithm where it is used to find + a particular solution for the inhomogeneous part of a recurrence. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy.solvers.recurr import rsolve_ratio + >>> rsolve_ratio([-2*x**3 + x**2 + 2*x - 1, 2*x**3 + x**2 - 6*x, + ... - 2*x**3 - 11*x**2 - 18*x - 9, 2*x**3 + 13*x**2 + 22*x + 8], 0, x) + C0*(2*x - 3)/(2*(x**2 - 1)) + + References + ========== + + .. [1] S. A. Abramov, Rational solutions of linear difference + and q-difference equations with polynomial coefficients, + in: T. Levelt, ed., Proc. ISSAC '95, ACM Press, New York, + 1995, 285-289 + + See Also + ======== + + rsolve_hyper + """ + f = sympify(f) + + if not f.is_polynomial(n): + return None + + coeffs = list(map(sympify, coeffs)) + + r = len(coeffs) - 1 + + A, B = coeffs[r], coeffs[0] + A = A.subs(n, n - r).expand() + + h = Dummy('h') + + res = resultant(A, B.subs(n, n + h), n) + + if not res.is_polynomial(h): + p, q = res.as_numer_denom() + res = quo(p, q, h) + + nni_roots = list(roots(res, h, filter='Z', + predicate=lambda r: r >= 0).keys()) + + if not nni_roots: + return rsolve_poly(coeffs, f, n, **hints) + else: + C, numers = S.One, [S.Zero]*(r + 1) + + for i in range(int(max(nni_roots)), -1, -1): + d = gcd(A, B.subs(n, n + i), n) + + A = quo(A, d, n) + B = quo(B, d.subs(n, n - i), n) + + C *= Mul(*[d.subs(n, n - j) for j in range(i + 1)]) + + denoms = [C.subs(n, n + i) for i in range(r + 1)] + + for i in range(r + 1): + g = gcd(coeffs[i], denoms[i], n) + + numers[i] = quo(coeffs[i], g, n) + denoms[i] = quo(denoms[i], g, n) + + for i in range(r + 1): + numers[i] *= Mul(*(denoms[:i] + denoms[i + 1:])) + + result = rsolve_poly(numers, f * Mul(*denoms), n, **hints) + + if result is not None: + if hints.get('symbols', False): + return (simplify(result[0] / C), result[1]) + else: + return simplify(result / C) + else: + return None + + +def rsolve_hyper(coeffs, f, n, **hints): + r""" + Given linear recurrence operator `\operatorname{L}` of order `k` + with polynomial coefficients and inhomogeneous equation + `\operatorname{L} y = f` we seek for all hypergeometric solutions + over field `K` of characteristic zero. + + The inhomogeneous part can be either hypergeometric or a sum + of a fixed number of pairwise dissimilar hypergeometric terms. + + The algorithm performs three basic steps: + + (1) Group together similar hypergeometric terms in the + inhomogeneous part of `\operatorname{L} y = f`, and find + particular solution using Abramov's algorithm. + + (2) Compute generating set of `\operatorname{L}` and find basis + in it, so that all solutions are linearly independent. + + (3) Form final solution with the number of arbitrary + constants equal to dimension of basis of `\operatorname{L}`. + + Term `a(n)` is hypergeometric if it is annihilated by first order + linear difference equations with polynomial coefficients or, in + simpler words, if consecutive term ratio is a rational function. + + The output of this procedure is a linear combination of fixed + number of hypergeometric terms. However the underlying method + can generate larger class of solutions - D'Alembertian terms. + + Note also that this method not only computes the kernel of the + inhomogeneous equation, but also reduces in to a basis so that + solutions generated by this procedure are linearly independent + + Examples + ======== + + >>> from sympy.solvers import rsolve_hyper + >>> from sympy.abc import x + + >>> rsolve_hyper([-1, -1, 1], 0, x) + C0*(1/2 - sqrt(5)/2)**x + C1*(1/2 + sqrt(5)/2)**x + + >>> rsolve_hyper([-1, 1], 1 + x, x) + C0 + x*(x + 1)/2 + + References + ========== + + .. [1] M. Petkovsek, Hypergeometric solutions of linear recurrences + with polynomial coefficients, J. Symbolic Computation, + 14 (1992), 243-264. + + .. [2] M. Petkovsek, H. S. Wilf, D. Zeilberger, A = B, 1996. + """ + coeffs = list(map(sympify, coeffs)) + + f = sympify(f) + + r, kernel, symbols = len(coeffs) - 1, [], set() + + if not f.is_zero: + if f.is_Add: + similar = {} + + for g in f.expand().args: + if not g.is_hypergeometric(n): + return None + + for h in similar.keys(): + if hypersimilar(g, h, n): + similar[h] += g + break + else: + similar[g] = S.Zero + + inhomogeneous = [g + h for g, h in similar.items()] + elif f.is_hypergeometric(n): + inhomogeneous = [f] + else: + return None + + for i, g in enumerate(inhomogeneous): + coeff, polys = S.One, coeffs[:] + denoms = [S.One]*(r + 1) + + s = hypersimp(g, n) + + for j in range(1, r + 1): + coeff *= s.subs(n, n + j - 1) + + p, q = coeff.as_numer_denom() + + polys[j] *= p + denoms[j] = q + + for j in range(r + 1): + polys[j] *= Mul(*(denoms[:j] + denoms[j + 1:])) + + # FIXME: The call to rsolve_ratio below should suffice (rsolve_poly + # call can be removed) but the XFAIL test_rsolve_ratio_missed must + # be fixed first. + R = rsolve_ratio(polys, Mul(*denoms), n, symbols=True) + if R is not None: + R, syms = R + if syms: + R = R.subs(zip(syms, [0]*len(syms))) + else: + R = rsolve_poly(polys, Mul(*denoms), n) + + if R: + inhomogeneous[i] *= R + else: + return None + + result = Add(*inhomogeneous) + result = simplify(result) + else: + result = S.Zero + + Z = Dummy('Z') + + p, q = coeffs[0], coeffs[r].subs(n, n - r + 1) + + p_factors = list(roots(p, n).keys()) + q_factors = list(roots(q, n).keys()) + + factors = [(S.One, S.One)] + + for p in p_factors: + for q in q_factors: + if p.is_integer and q.is_integer and p <= q: + continue + else: + factors += [(n - p, n - q)] + + p = [(n - p, S.One) for p in p_factors] + q = [(S.One, n - q) for q in q_factors] + + factors = p + factors + q + + for A, B in factors: + polys, degrees = [], [] + D = A*B.subs(n, n + r - 1) + + for i in range(r + 1): + a = Mul(*[A.subs(n, n + j) for j in range(i)]) + b = Mul(*[B.subs(n, n + j) for j in range(i, r)]) + + poly = quo(coeffs[i]*a*b, D, n) + polys.append(poly.as_poly(n)) + + if not poly.is_zero: + degrees.append(polys[i].degree()) + + if degrees: + d, poly = max(degrees), S.Zero + else: + return None + + for i in range(r + 1): + coeff = polys[i].nth(d) + + if coeff is not S.Zero: + poly += coeff * Z**i + + for z in roots(poly, Z).keys(): + if z.is_zero: + continue + + recurr_coeffs = [polys[i].as_expr()*z**i for i in range(r + 1)] + if d == 0 and 0 != Add(*[recurr_coeffs[j]*j for j in range(1, r + 1)]): + # faster inline check (than calling rsolve_poly) for a + # constant solution to a constant coefficient recurrence. + sol = [Symbol("C" + str(len(symbols)))] + else: + sol, syms = rsolve_poly(recurr_coeffs, 0, n, len(symbols), symbols=True) + sol = sol.collect(syms) + sol = [sol.coeff(s) for s in syms] + + for C in sol: + ratio = z * A * C.subs(n, n + 1) / B / C + ratio = simplify(ratio) + # If there is a nonnegative root in the denominator of the ratio, + # this indicates that the term y(n_root) is zero, and one should + # start the product with the term y(n_root + 1). + n0 = 0 + for n_root in roots(ratio.as_numer_denom()[1], n).keys(): + if n_root.has(I): + return None + elif (n0 < (n_root + 1)) == True: + n0 = n_root + 1 + K = product(ratio, (n, n0, n - 1)) + if K.has(factorial, FallingFactorial, RisingFactorial): + K = simplify(K) + + if casoratian(kernel + [K], n, zero=False) != 0: + kernel.append(K) + + kernel.sort(key=default_sort_key) + sk = list(zip(numbered_symbols('C'), kernel)) + + for C, ker in sk: + result += C * ker + + if hints.get('symbols', False): + # XXX: This returns the symbols in a non-deterministic order + symbols |= {s for s, k in sk} + return (result, list(symbols)) + else: + return result + + +def rsolve(f, y, init=None): + r""" + Solve univariate recurrence with rational coefficients. + + Given `k`-th order linear recurrence `\operatorname{L} y = f`, + or equivalently: + + .. math:: a_{k}(n) y(n+k) + a_{k-1}(n) y(n+k-1) + + \cdots + a_{0}(n) y(n) = f(n) + + where `a_{i}(n)`, for `i=0, \ldots, k`, are polynomials or rational + functions in `n`, and `f` is a hypergeometric function or a sum + of a fixed number of pairwise dissimilar hypergeometric terms in + `n`, finds all solutions or returns ``None``, if none were found. + + Initial conditions can be given as a dictionary in two forms: + + (1) ``{ n_0 : v_0, n_1 : v_1, ..., n_m : v_m}`` + (2) ``{y(n_0) : v_0, y(n_1) : v_1, ..., y(n_m) : v_m}`` + + or as a list ``L`` of values: + + ``L = [v_0, v_1, ..., v_m]`` + + where ``L[i] = v_i``, for `i=0, \ldots, m`, maps to `y(n_i)`. + + Examples + ======== + + Lets consider the following recurrence: + + .. math:: (n - 1) y(n + 2) - (n^2 + 3 n - 2) y(n + 1) + + 2 n (n + 1) y(n) = 0 + + >>> from sympy import Function, rsolve + >>> from sympy.abc import n + >>> y = Function('y') + + >>> f = (n - 1)*y(n + 2) - (n**2 + 3*n - 2)*y(n + 1) + 2*n*(n + 1)*y(n) + + >>> rsolve(f, y(n)) + 2**n*C0 + C1*factorial(n) + + >>> rsolve(f, y(n), {y(0):0, y(1):3}) + 3*2**n - 3*factorial(n) + + See Also + ======== + + rsolve_poly, rsolve_ratio, rsolve_hyper + + """ + if isinstance(f, Equality): + f = f.lhs - f.rhs + + n = y.args[0] + k = Wild('k', exclude=(n,)) + + # Preprocess user input to allow things like + # y(n) + a*(y(n + 1) + y(n - 1))/2 + f = f.expand().collect(y.func(Wild('m', integer=True))) + + h_part = defaultdict(list) + i_part = [] + for g in Add.make_args(f): + coeff, dep = g.as_coeff_mul(y.func) + if not dep: + i_part.append(coeff) + continue + for h in dep: + if h.is_Function and h.func == y.func: + result = h.args[0].match(n + k) + if result is not None: + h_part[int(result[k])].append(coeff) + continue + raise ValueError( + "'%s(%s + k)' expected, got '%s'" % (y.func, n, h)) + for k in h_part: + h_part[k] = Add(*h_part[k]) + h_part.default_factory = lambda: 0 + i_part = Add(*i_part) + + for k, coeff in h_part.items(): + h_part[k] = simplify(coeff) + + common = S.One + + if not i_part.is_zero and not i_part.is_hypergeometric(n) and \ + not (i_part.is_Add and all((x.is_hypergeometric(n) for x in i_part.expand().args))): + raise ValueError("The independent term should be a sum of hypergeometric functions, got '%s'" % i_part) + + for coeff in h_part.values(): + if coeff.is_rational_function(n): + if not coeff.is_polynomial(n): + common = lcm(common, coeff.as_numer_denom()[1], n) + else: + raise ValueError( + "Polynomial or rational function expected, got '%s'" % coeff) + + i_numer, i_denom = i_part.as_numer_denom() + + if i_denom.is_polynomial(n): + common = lcm(common, i_denom, n) + + if common is not S.One: + for k, coeff in h_part.items(): + numer, denom = coeff.as_numer_denom() + h_part[k] = numer*quo(common, denom, n) + + i_part = i_numer*quo(common, i_denom, n) + + K_min = min(h_part.keys()) + + if K_min < 0: + K = abs(K_min) + + H_part = defaultdict(lambda: S.Zero) + i_part = i_part.subs(n, n + K).expand() + common = common.subs(n, n + K).expand() + + for k, coeff in h_part.items(): + H_part[k + K] = coeff.subs(n, n + K).expand() + else: + H_part = h_part + + K_max = max(H_part.keys()) + coeffs = [H_part[i] for i in range(K_max + 1)] + + result = rsolve_hyper(coeffs, -i_part, n, symbols=True) + + if result is None: + return None + + solution, symbols = result + + if init in ({}, []): + init = None + + if symbols and init is not None: + if isinstance(init, list): + init = {i: init[i] for i in range(len(init))} + + equations = [] + + for k, v in init.items(): + try: + i = int(k) + except TypeError: + if k.is_Function and k.func == y.func: + i = int(k.args[0]) + else: + raise ValueError("Integer or term expected, got '%s'" % k) + + eq = solution.subs(n, i) - v + if eq.has(S.NaN): + eq = solution.limit(n, i) - v + equations.append(eq) + + result = solve(equations, *symbols) + + if not result: + return None + else: + solution = solution.subs(result) + + return solution diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/simplex.py b/.venv/lib/python3.13/site-packages/sympy/solvers/simplex.py new file mode 100644 index 0000000000000000000000000000000000000000..c8e652cb626507d7829f9bc1c78fc6f49809865f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/simplex.py @@ -0,0 +1,1104 @@ +"""Tools for optimizing a linear function for a given simplex. + +For the linear objective function ``f`` with linear constraints +expressed using `Le`, `Ge` or `Eq` can be found with ``lpmin`` or +``lpmax``. The symbols are **unbounded** unless specifically +constrained. + +As an alternative, the matrices describing the objective and the +constraints, and an optional list of bounds can be passed to +``linprog`` which will solve for the minimization of ``C*x`` +under constraints ``A*x <= b`` and/or ``Aeq*x = beq``, and +individual bounds for variables given as ``(lo, hi)``. The values +returned are **nonnegative** unless bounds are provided that +indicate otherwise. + +Errors that might be raised are UnboundedLPError when there is no +finite solution for the system or InfeasibleLPError when the +constraints represent impossible conditions (i.e. a non-existent + simplex). + +Here is a simple 1-D system: minimize `x` given that ``x >= 1``. + + >>> from sympy.solvers.simplex import lpmin, linprog + >>> from sympy.abc import x + + The function and a list with the constraint is passed directly + to `lpmin`: + + >>> lpmin(x, [x >= 1]) + (1, {x: 1}) + + For `linprog` the matrix for the objective is `[1]` and the + uivariate constraint can be passed as a bound with None acting + as infinity: + + >>> linprog([1], bounds=(1, None)) + (1, [1]) + + Or the matrices, corresponding to ``x >= 1`` expressed as + ``-x <= -1`` as required by the routine, can be passed: + + >>> linprog([1], [-1], [-1]) + (1, [1]) + + If there is no limit for the objective, an error is raised. + In this case there is a valid region of interest (simplex) + but no limit to how small ``x`` can be: + + >>> lpmin(x, []) + Traceback (most recent call last): + ... + sympy.solvers.simplex.UnboundedLPError: + Objective function can assume arbitrarily large values! + + An error is raised if there is no possible solution: + + >>> lpmin(x,[x<=1,x>=2]) + Traceback (most recent call last): + ... + sympy.solvers.simplex.InfeasibleLPError: + Inconsistent/False constraint +""" + +from sympy.core import sympify +from sympy.core.exprtools import factor_terms +from sympy.core.relational import Le, Ge, Eq +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.core.sorting import ordered +from sympy.functions.elementary.complexes import sign +from sympy.matrices.dense import Matrix, zeros +from sympy.solvers.solveset import linear_eq_to_matrix +from sympy.utilities.iterables import numbered_symbols +from sympy.utilities.misc import filldedent + + +class UnboundedLPError(Exception): + """ + A linear programming problem is said to be unbounded if its objective + function can assume arbitrarily large values. + + Example + ======= + + Suppose you want to maximize + 2x + subject to + x >= 0 + + There's no upper limit that 2x can take. + """ + + pass + + +class InfeasibleLPError(Exception): + """ + A linear programming problem is considered infeasible if its + constraint set is empty. That is, if the set of all vectors + satisfying the constraints is empty, then the problem is infeasible. + + Example + ======= + + Suppose you want to maximize + x + subject to + x >= 10 + x <= 9 + + No x can satisfy those constraints. + """ + + pass + + +def _pivot(M, i, j): + """ + The pivot element `M[i, j]` is inverted and the rest of the matrix + modified and returned as a new matrix; original is left unmodified. + + Example + ======= + + >>> from sympy.matrices.dense import Matrix + >>> from sympy.solvers.simplex import _pivot + >>> from sympy import var + >>> Matrix(3, 3, var('a:i')) + Matrix([ + [a, b, c], + [d, e, f], + [g, h, i]]) + >>> _pivot(_, 1, 0) + Matrix([ + [-a/d, -a*e/d + b, -a*f/d + c], + [ 1/d, e/d, f/d], + [-g/d, h - e*g/d, i - f*g/d]]) + """ + Mi, Mj, Mij = M[i, :], M[:, j], M[i, j] + if Mij == 0: + raise ZeroDivisionError( + "Tried to pivot about zero-valued entry.") + A = M - Mj * (Mi / Mij) + A[i, :] = Mi / Mij + A[:, j] = -Mj / Mij + A[i, j] = 1 / Mij + return A + + +def _choose_pivot_row(A, B, candidate_rows, pivot_col, Y): + # Choose row with smallest ratio + # If there are ties, pick using Bland's rule + return min(candidate_rows, key=lambda i: (B[i] / A[i, pivot_col], Y[i])) + + +def _simplex(A, B, C, D=None, dual=False): + """Return ``(o, x, y)`` obtained from the two-phase simplex method + using Bland's rule: ``o`` is the minimum value of primal, + ``Cx - D``, under constraints ``Ax <= B`` (with ``x >= 0``) and + the maximum of the dual, ``y^{T}B - D``, under constraints + ``A^{T}*y >= C^{T}`` (with ``y >= 0``). To compute the dual of + the system, pass `dual=True` and ``(o, y, x)`` will be returned. + + Note: the nonnegative constraints for ``x`` and ``y`` supercede + any values of ``A`` and ``B`` that are inconsistent with that + assumption, so if a constraint of ``x >= -1`` is represented + in ``A`` and ``B``, no value will be obtained that is negative; if + a constraint of ``x <= -1`` is represented, an error will be + raised since no solution is possible. + + This routine relies on the ability of determining whether an + expression is 0 or not. This is guaranteed if the input contains + only Float or Rational entries. It will raise a TypeError if + a relationship does not evaluate to True or False. + + Examples + ======== + + >>> from sympy.solvers.simplex import _simplex + >>> from sympy import Matrix + + Consider the simple minimization of ``f = x + y + 1`` under the + constraint that ``y + 2*x >= 4``. This is the "standard form" of + a minimization. + + In the nonnegative quadrant, this inequality describes a area above + a triangle with vertices at (0, 4), (0, 0) and (2, 0). The minimum + of ``f`` occurs at (2, 0). Define A, B, C, D for the standard + minimization: + + >>> A = Matrix([[2, 1]]) + >>> B = Matrix([4]) + >>> C = Matrix([[1, 1]]) + >>> D = Matrix([-1]) + + Confirm that this is the system of interest: + + >>> from sympy.abc import x, y + >>> X = Matrix([x, y]) + >>> (C*X - D)[0] + x + y + 1 + >>> [i >= j for i, j in zip(A*X, B)] + [2*x + y >= 4] + + Since `_simplex` will do a minimization for constraints given as + ``A*x <= B``, the signs of ``A`` and ``B`` must be negated since + the currently correspond to a greater-than inequality: + + >>> _simplex(-A, -B, C, D) + (3, [2, 0], [1/2]) + + The dual of minimizing ``f`` is maximizing ``F = c*y - d`` for + ``a*y <= b`` where ``a``, ``b``, ``c``, ``d`` are derived from the + transpose of the matrix representation of the standard minimization: + + >>> tr = lambda a, b, c, d: [i.T for i in (a, c, b, d)] + >>> a, b, c, d = tr(A, B, C, D) + + This time ``a*x <= b`` is the expected inequality for the `_simplex` + method, but to maximize ``F``, the sign of ``c`` and ``d`` must be + changed (so that minimizing the negative will give the negative of + the maximum of ``F``): + + >>> _simplex(a, b, -c, -d) + (-3, [1/2], [2, 0]) + + The negative of ``F`` and the min of ``f`` are the same. The dual + point `[1/2]` is the value of ``y`` that minimized ``F = c*y - d`` + under constraints a*x <= b``: + + >>> y = Matrix(['y']) + >>> (c*y - d)[0] + 4*y + 1 + >>> [i <= j for i, j in zip(a*y,b)] + [2*y <= 1, y <= 1] + + In this 1-dimensional dual system, the more restrictive constraint is + the first which limits ``y`` between 0 and 1/2 and the maximum of + ``F`` is attained at the nonzero value, hence is ``4*(1/2) + 1 = 3``. + + In this case the values for ``x`` and ``y`` were the same when the + dual representation was solved. This is not always the case (though + the value of the function will be the same). + + >>> l = [[1, 1], [-1, 1], [0, 1], [-1, 0]], [5, 1, 2, -1], [[1, 1]], [-1] + >>> A, B, C, D = [Matrix(i) for i in l] + >>> _simplex(A, B, -C, -D) + (-6, [3, 2], [1, 0, 0, 0]) + >>> _simplex(A, B, -C, -D, dual=True) # [5, 0] != [3, 2] + (-6, [1, 0, 0, 0], [5, 0]) + + In both cases the function has the same value: + + >>> Matrix(C)*Matrix([3, 2]) == Matrix(C)*Matrix([5, 0]) + True + + See Also + ======== + _lp - poses min/max problem in form compatible with _simplex + lpmin - minimization which calls _lp + lpmax - maximimzation which calls _lp + + References + ========== + + .. [1] Thomas S. Ferguson, LINEAR PROGRAMMING: A Concise Introduction + web.tecnico.ulisboa.pt/mcasquilho/acad/or/ftp/FergusonUCLA_lp.pdf + + """ + A, B, C, D = [Matrix(i) for i in (A, B, C, D or [0])] + if dual: + _o, d, p = _simplex(-A.T, C.T, B.T, -D) + return -_o, d, p + + if A and B: + M = Matrix([[A, B], [C, D]]) + else: + if A or B: + raise ValueError("must give A and B") + # no constraints given + M = Matrix([[C, D]]) + n = M.cols - 1 + m = M.rows - 1 + + if not all(i.is_Float or i.is_Rational for i in M): + # with literal Float and Rational we are guaranteed the + # ability of determining whether an expression is 0 or not + raise TypeError(filldedent(""" + Only rationals and floats are allowed. + """ + ) + ) + + # x variables have priority over y variables during Bland's rule + # since False < True + X = [(False, j) for j in range(n)] + Y = [(True, i) for i in range(m)] + + # Phase 1: find a feasible solution or determine none exist + + ## keep track of last pivot row and column + last = None + + while True: + B = M[:-1, -1] + A = M[:-1, :-1] + if all(B[i] >= 0 for i in range(B.rows)): + # We have found a feasible solution + break + + # Find k: first row with a negative rightmost entry + for k in range(B.rows): + if B[k] < 0: + break # use current value of k below + else: + pass # error will raise below + + # Choose pivot column, c + piv_cols = [_ for _ in range(A.cols) if A[k, _] < 0] + if not piv_cols: + raise InfeasibleLPError(filldedent(""" + The constraint set is empty!""")) + _, c = min((X[i], i) for i in piv_cols) # Bland's rule + + # Choose pivot row, r + piv_rows = [_ for _ in range(A.rows) if A[_, c] > 0 and B[_] > 0] + piv_rows.append(k) + r = _choose_pivot_row(A, B, piv_rows, c, Y) + + # check for oscillation + if (r, c) == last: + # Not sure what to do here; it looks like there will be + # oscillations; see o1 test added at this commit to + # see a system with no solution and the o2 for one + # with a solution. In the case of o2, the solution + # from linprog is the same as the one from lpmin, but + # the matrices created in the lpmin case are different + # than those created without replacements in linprog and + # the matrices in the linprog case lead to oscillations. + # If the matrices could be re-written in linprog like + # lpmin does, this behavior could be avoided and then + # perhaps the oscillating case would only occur when + # there is no solution. For now, the output is checked + # before exit if oscillations were detected and an + # error is raised there if the solution was invalid. + # + # cf section 6 of Ferguson for a non-cycling modification + last = True + break + last = r, c + + M = _pivot(M, r, c) + X[c], Y[r] = Y[r], X[c] + + # Phase 2: from a feasible solution, pivot to optimal + while True: + B = M[:-1, -1] + A = M[:-1, :-1] + C = M[-1, :-1] + + # Choose a pivot column, c + piv_cols = [_ for _ in range(n) if C[_] < 0] + if not piv_cols: + break + _, c = min((X[i], i) for i in piv_cols) # Bland's rule + + # Choose a pivot row, r + piv_rows = [_ for _ in range(m) if A[_, c] > 0] + if not piv_rows: + raise UnboundedLPError(filldedent(""" + Objective function can assume + arbitrarily large values!""")) + r = _choose_pivot_row(A, B, piv_rows, c, Y) + + M = _pivot(M, r, c) + X[c], Y[r] = Y[r], X[c] + + argmax = [None] * n + argmin_dual = [None] * m + + for i, (v, n) in enumerate(X): + if v == False: + argmax[n] = 0 + else: + argmin_dual[n] = M[-1, i] + + for i, (v, n) in enumerate(Y): + if v == True: + argmin_dual[n] = 0 + else: + argmax[n] = M[i, -1] + + if last and not all(i >= 0 for i in argmax + argmin_dual): + raise InfeasibleLPError(filldedent(""" + Oscillating system led to invalid solution. + If you believe there was a valid solution, please + report this as a bug.""")) + return -M[-1, -1], argmax, argmin_dual + + +## routines that use _simplex or support those that do + + +def _abcd(M, list=False): + """return parts of M as matrices or lists + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.solvers.simplex import _abcd + + >>> m = Matrix(3, 3, range(9)); m + Matrix([ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> a, b, c, d = _abcd(m) + >>> a + Matrix([ + [0, 1], + [3, 4]]) + >>> b + Matrix([ + [2], + [5]]) + >>> c + Matrix([[6, 7]]) + >>> d + Matrix([[8]]) + + The matrices can be returned as compact lists, too: + + >>> L = a, b, c, d = _abcd(m, list=True); L + ([[0, 1], [3, 4]], [2, 5], [[6, 7]], [8]) + """ + + def aslist(i): + l = i.tolist() + if len(l[0]) == 1: # col vector + return [i[0] for i in l] + return l + + m = M[:-1, :-1], M[:-1, -1], M[-1, :-1], M[-1:, -1:] + if not list: + return m + return tuple([aslist(i) for i in m]) + + +def _m(a, b, c, d=None): + """return Matrix([[a, b], [c, d]]) from matrices + in Matrix or list form. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.solvers.simplex import _abcd, _m + >>> m = Matrix(3, 3, range(9)) + >>> L = _abcd(m, list=True); L + ([[0, 1], [3, 4]], [2, 5], [[6, 7]], [8]) + >>> _abcd(m) + (Matrix([ + [0, 1], + [3, 4]]), Matrix([ + [2], + [5]]), Matrix([[6, 7]]), Matrix([[8]])) + >>> assert m == _m(*L) == _m(*_) + """ + a, b, c, d = [Matrix(i) for i in (a, b, c, d or [0])] + return Matrix([[a, b], [c, d]]) + + +def _primal_dual(M, factor=True): + """return primal and dual function and constraints + assuming that ``M = Matrix([[A, b], [c, d]])`` and the + function ``c*x - d`` is being minimized with ``Ax >= b`` + for nonnegative values of ``x``. The dual and its + constraints will be for maximizing `b.T*y - d` subject + to ``A.T*y <= c.T``. + + Examples + ======== + + >>> from sympy.solvers.simplex import _primal_dual, lpmin, lpmax + >>> from sympy import Matrix + + The following matrix represents the primal task of + minimizing x + y + 7 for y >= x + 1 and y >= -2*x + 3. + The dual task seeks to maximize x + 3*y + 7 with + 2*y - x <= 1 and and x + y <= 1: + + >>> M = Matrix([ + ... [-1, 1, 1], + ... [ 2, 1, 3], + ... [ 1, 1, -7]]) + >>> p, d = _primal_dual(M) + + The minimum of the primal and maximum of the dual are the same + (though they occur at different points): + + >>> lpmin(*p) + (28/3, {x1: 2/3, x2: 5/3}) + >>> lpmax(*d) + (28/3, {y1: 1/3, y2: 2/3}) + + If the equivalent (but canonical) inequalities are + desired, leave `factor=True`, otherwise the unmodified + inequalities for M will be returned. + + >>> m = Matrix([ + ... [-3, -2, 4, -2], + ... [ 2, 0, 0, -2], + ... [ 0, 1, -3, 0]]) + + >>> _primal_dual(m, False) # last condition is 2*x1 >= -2 + ((x2 - 3*x3, + [-3*x1 - 2*x2 + 4*x3 >= -2, 2*x1 >= -2]), + (-2*y1 - 2*y2, + [-3*y1 + 2*y2 <= 0, -2*y1 <= 1, 4*y1 <= -3])) + + >>> _primal_dual(m) # condition now x1 >= -1 + ((x2 - 3*x3, + [-3*x1 - 2*x2 + 4*x3 >= -2, x1 >= -1]), + (-2*y1 - 2*y2, + [-3*y1 + 2*y2 <= 0, -2*y1 <= 1, 4*y1 <= -3])) + + If you pass the transpose of the matrix, the primal will be + identified as the standard minimization problem and the + dual as the standard maximization: + + >>> _primal_dual(m.T) + ((-2*x1 - 2*x2, + [-3*x1 + 2*x2 >= 0, -2*x1 >= 1, 4*x1 >= -3]), + (y2 - 3*y3, + [-3*y1 - 2*y2 + 4*y3 <= -2, y1 <= -1])) + + A matrix must have some size or else None will be returned for + the functions: + + >>> _primal_dual(Matrix([[1, 2]])) + ((x1 - 2, []), (-2, [])) + + >>> _primal_dual(Matrix([])) + ((None, []), (None, [])) + + References + ========== + + .. [1] David Galvin, Relations between Primal and Dual + www3.nd.edu/~dgalvin1/30210/30210_F07/presentations/dual_opt.pdf + """ + if not M: + return (None, []), (None, []) + if not hasattr(M, "shape"): + if len(M) not in (3, 4): + raise ValueError("expecting Matrix or 3 or 4 lists") + M = _m(*M) + m, n = [i - 1 for i in M.shape] + A, b, c, d = _abcd(M) + d = d[0] + _ = lambda x: numbered_symbols(x, start=1) + x = Matrix([i for i, j in zip(_("x"), range(n))]) + yT = Matrix([i for i, j in zip(_("y"), range(m))]).T + + def ineq(L, r, op): + rv = [] + for r in (op(i, j) for i, j in zip(L, r)): + if r == True: + continue + elif r == False: + return [False] + if factor: + f = factor_terms(r) + if f.lhs.is_Mul and f.rhs % f.lhs.args[0] == 0: + assert len(f.lhs.args) == 2, f.lhs + k = f.lhs.args[0] + r = r.func(sign(k) * f.lhs.args[1], f.rhs // abs(k)) + rv.append(r) + return rv + + eq = lambda x, d: x[0] - d if x else -d + F = eq(c * x, d) + f = eq(yT * b, d) + return (F, ineq(A * x, b, Ge)), (f, ineq(yT * A, c, Le)) + + +def _rel_as_nonpos(constr, syms): + """return `(np, d, aux)` where `np` is a list of nonpositive + expressions that represent the given constraints (possibly + rewritten in terms of auxilliary variables) expressible with + nonnegative symbols, and `d` is a dictionary mapping a given + symbols to an expression with an auxilliary variable. In some + cases a symbol will be used as part of the change of variables, + e.g. x: x - z1 instead of x: z1 - z2. + + If any constraint is False/empty, return None. All variables in + ``constr`` are assumed to be unbounded unless explicitly indicated + otherwise with a univariate constraint, e.g. ``x >= 0`` will + restrict ``x`` to nonnegative values. + + The ``syms`` must be included so all symbols can be given an + unbounded assumption if they are not otherwise bound with + univariate conditions like ``x <= 3``. + + Examples + ======== + + >>> from sympy.solvers.simplex import _rel_as_nonpos + >>> from sympy.abc import x, y + >>> _rel_as_nonpos([x >= y, x >= 0, y >= 0], (x, y)) + ([-x + y], {}, []) + >>> _rel_as_nonpos([x >= 3, x <= 5], [x]) + ([_z1 - 2], {x: _z1 + 3}, [_z1]) + >>> _rel_as_nonpos([x <= 5], [x]) + ([], {x: 5 - _z1}, [_z1]) + >>> _rel_as_nonpos([x >= 1], [x]) + ([], {x: _z1 + 1}, [_z1]) + """ + r = {} # replacements to handle change of variables + np = [] # nonpositive expressions + aux = [] # auxilliary symbols added + ui = numbered_symbols("z", start=1, cls=Dummy) # auxilliary symbols + univariate = {} # {x: interval} for univariate constraints + unbound = [] # symbols designated as unbound + syms = set(syms) # the expected syms of the system + + # separate out univariates + for i in constr: + if i == True: + continue # ignore + if i == False: + return # no solution + if i.has(S.Infinity, S.NegativeInfinity): + raise ValueError("only finite bounds are permitted") + if isinstance(i, (Le, Ge)): + i = i.lts - i.gts + freei = i.free_symbols + if freei - syms: + raise ValueError( + "unexpected symbol(s) in constraint: %s" % (freei - syms) + ) + if len(freei) > 1: + np.append(i) + elif freei: + x = freei.pop() + if x in unbound: + continue # will handle later + ivl = Le(i, 0, evaluate=False).as_set() + if x not in univariate: + univariate[x] = ivl + else: + univariate[x] &= ivl + elif i: + return False + else: + raise TypeError(filldedent(""" + only equalities like Eq(x, y) or non-strict + inequalities like x >= y are allowed in lp, not %s""" % i)) + + # introduce auxilliary variables as needed for univariate + # inequalities + for x in syms: + i = univariate.get(x, True) + if not i: + return None # no solution possible + if i == True: + unbound.append(x) + continue + a, b = i.inf, i.sup + if a.is_infinite: + u = next(ui) + r[x] = b - u + aux.append(u) + elif b.is_infinite: + if a: + u = next(ui) + r[x] = a + u + aux.append(u) + else: + # standard nonnegative relationship + pass + else: + u = next(ui) + aux.append(u) + # shift so u = x - a => x = u + a + r[x] = u + a + # add constraint for u <= b - a + # since when u = b-a then x = u + a = b - a + a = b: + # the upper limit for x + np.append(u - (b - a)) + + # make change of variables for unbound variables + for x in unbound: + u = next(ui) + r[x] = u - x # reusing x + aux.append(u) + + return np, r, aux + + +def _lp_matrices(objective, constraints): + """return A, B, C, D, r, x+X, X for maximizing + objective = Cx - D with constraints Ax <= B, introducing + introducing auxilliary variables, X, as necessary to make + replacements of symbols as given in r, {xi: expression with Xj}, + so all variables in x+X will take on nonnegative values. + + Every univariate condition creates a semi-infinite + condition, e.g. a single ``x <= 3`` creates the + interval ``[-oo, 3]`` while ``x <= 3`` and ``x >= 2`` + create an interval ``[2, 3]``. Variables not in a univariate + expression will take on nonnegative values. + """ + + # sympify input and collect free symbols + F = sympify(objective) + np = [sympify(i) for i in constraints] + syms = set.union(*[i.free_symbols for i in [F] + np], set()) + + # change Eq(x, y) to x - y <= 0 and y - x <= 0 + for i in range(len(np)): + if isinstance(np[i], Eq): + np[i] = np[i].lhs - np[i].rhs <= 0 + np.append(-np[i].lhs <= 0) + + # convert constraints to nonpositive expressions + _ = _rel_as_nonpos(np, syms) + if _ is None: + raise InfeasibleLPError(filldedent(""" + Inconsistent/False constraint""")) + np, r, aux = _ + + # do change of variables + F = F.xreplace(r) + np = [i.xreplace(r) for i in np] + + # convert to matrices + xx = list(ordered(syms)) + aux + A, B = linear_eq_to_matrix(np, xx) + C, D = linear_eq_to_matrix([F], xx) + return A, B, C, D, r, xx, aux + + +def _lp(min_max, f, constr): + """Return the optimization (min or max) of ``f`` with the given + constraints. All variables are unbounded unless constrained. + + If `min_max` is 'max' then the results corresponding to the + maximization of ``f`` will be returned, else the minimization. + The constraints can be given as Le, Ge or Eq expressions. + + Examples + ======== + + >>> from sympy.solvers.simplex import _lp as lp + >>> from sympy import Eq + >>> from sympy.abc import x, y, z + >>> f = x + y - 2*z + >>> c = [7*x + 4*y - 7*z <= 3, 3*x - y + 10*z <= 6] + >>> c += [i >= 0 for i in (x, y, z)] + >>> lp(min, f, c) + (-6/5, {x: 0, y: 0, z: 3/5}) + + By passing max, the maximum value for f under the constraints + is returned (if possible): + + >>> lp(max, f, c) + (3/4, {x: 0, y: 3/4, z: 0}) + + Constraints that are equalities will require that the solution + also satisfy them: + + >>> lp(max, f, c + [Eq(y - 9*x, 1)]) + (5/7, {x: 0, y: 1, z: 1/7}) + + All symbols are reported, even if they are not in the objective + function: + + >>> lp(min, x, [y + x >= 3, x >= 0]) + (0, {x: 0, y: 3}) + """ + # get the matrix components for the system expressed + # in terms of only nonnegative variables + A, B, C, D, r, xx, aux = _lp_matrices(f, constr) + + how = str(min_max).lower() + if "max" in how: + # _simplex minimizes for Ax <= B so we + # have to change the sign of the function + # and negate the optimal value returned + _o, p, d = _simplex(A, B, -C, -D) + o = -_o + elif "min" in how: + o, p, d = _simplex(A, B, C, D) + else: + raise ValueError("expecting min or max") + + # restore original variables and remove aux from p + p = dict(zip(xx, p)) + if r: # p has original symbols and auxilliary symbols + # if r has x: x - z1 use values from p to update + r = {k: v.xreplace(p) for k, v in r.items()} + # then use the actual value of x (= x - z1) in p + p.update(r) + # don't show aux + p = {k: p[k] for k in ordered(p) if k not in aux} + + # not returning dual since there may be extra constraints + # when a variable has finite bounds + return o, p + + +def lpmin(f, constr): + """return minimum of linear equation ``f`` under + linear constraints expressed using Ge, Le or Eq. + + All variables are unbounded unless constrained. + + Examples + ======== + + >>> from sympy.solvers.simplex import lpmin + >>> from sympy import Eq + >>> from sympy.abc import x, y + >>> lpmin(x, [2*x - 3*y >= -1, Eq(x + 3*y, 2), x <= 2*y]) + (1/3, {x: 1/3, y: 5/9}) + + Negative values for variables are permitted unless explicitly + excluding, so minimizing ``x`` for ``x <= 3`` is an + unbounded problem while the following has a bounded solution: + + >>> lpmin(x, [x >= 0, x <= 3]) + (0, {x: 0}) + + Without indicating that ``x`` is nonnegative, there + is no minimum for this objective: + + >>> lpmin(x, [x <= 3]) + Traceback (most recent call last): + ... + sympy.solvers.simplex.UnboundedLPError: + Objective function can assume arbitrarily large values! + + See Also + ======== + linprog, lpmax + """ + return _lp(min, f, constr) + + +def lpmax(f, constr): + """return maximum of linear equation ``f`` under + linear constraints expressed using Ge, Le or Eq. + + All variables are unbounded unless constrained. + + Examples + ======== + + >>> from sympy.solvers.simplex import lpmax + >>> from sympy import Eq + >>> from sympy.abc import x, y + >>> lpmax(x, [2*x - 3*y >= -1, Eq(x+ 3*y,2), x <= 2*y]) + (4/5, {x: 4/5, y: 2/5}) + + Negative values for variables are permitted unless explicitly + excluding: + + >>> lpmax(x, [x <= -1]) + (-1, {x: -1}) + + If a non-negative constraint is added for x, there is no + possible solution: + + >>> lpmax(x, [x <= -1, x >= 0]) + Traceback (most recent call last): + ... + sympy.solvers.simplex.InfeasibleLPError: inconsistent/False constraint + + See Also + ======== + linprog, lpmin + """ + return _lp(max, f, constr) + + +def _handle_bounds(bounds): + # introduce auxiliary variables as needed for univariate + # inequalities + + def _make_list(length: int, index_value_pairs): + li = [0] * length + for idx, val in index_value_pairs: + li[idx] = val + return li + + unbound = [] + row = [] + row2 = [] + b_len = len(bounds) + for x, (a, b) in enumerate(bounds): + if a is None and b is None: + unbound.append(x) + elif a is None: + # r[x] = b - u + b_len += 1 + row.append(_make_list(b_len, [(x, 1), (-1, 1)])) + row.append(_make_list(b_len, [(x, -1), (-1, -1)])) + row2.extend([[b], [-b]]) + elif b is None: + if a: + # r[x] = a + u + b_len += 1 + row.append(_make_list(b_len, [(x, 1), (-1, -1)])) + row.append(_make_list(b_len, [(x, -1), (-1, 1)])) + row2.extend([[a], [-a]]) + else: + # standard nonnegative relationship + pass + else: + # r[x] = u + a + b_len += 1 + row.append(_make_list(b_len, [(x, 1), (-1, -1)])) + row.append(_make_list(b_len, [(x, -1), (-1, 1)])) + # u <= b - a + row.append(_make_list(b_len, [(-1, 1)])) + row2.extend([[a], [-a], [b - a]]) + + # make change of variables for unbound variables + for x in unbound: + # r[x] = u - v + b_len += 2 + row.append(_make_list(b_len, [(x, 1), (-1, 1), (-2, -1)])) + row.append(_make_list(b_len, [(x, -1), (-1, -1), (-2, 1)])) + row2.extend([[0], [0]]) + + return Matrix([r + [0]*(b_len - len(r)) for r in row]), Matrix(row2) + + +def linprog(c, A=None, b=None, A_eq=None, b_eq=None, bounds=None): + """Return the minimization of ``c*x`` with the given + constraints ``A*x <= b`` and ``A_eq*x = b_eq``. Unless bounds + are given, variables will have nonnegative values in the solution. + + If ``A`` is not given, then the dimension of the system will + be determined by the length of ``C``. + + By default, all variables will be nonnegative. If ``bounds`` + is given as a single tuple, ``(lo, hi)``, then all variables + will be constrained to be between ``lo`` and ``hi``. Use + None for a ``lo`` or ``hi`` if it is unconstrained in the + negative or positive direction, respectively, e.g. + ``(None, 0)`` indicates nonpositive values. To set + individual ranges, pass a list with length equal to the + number of columns in ``A``, each element being a tuple; if + only a few variables take on non-default values they can be + passed as a dictionary with keys giving the corresponding + column to which the variable is assigned, e.g. ``bounds={2: + (1, 4)}`` would limit the 3rd variable to have a value in + range ``[1, 4]``. + + Examples + ======== + + >>> from sympy.solvers.simplex import linprog + >>> from sympy import symbols, Eq, linear_eq_to_matrix as M, Matrix + >>> x = x1, x2, x3, x4 = symbols('x1:5') + >>> X = Matrix(x) + >>> c, d = M(5*x2 + x3 + 4*x4 - x1, x) + >>> a, b = M([5*x2 + 2*x3 + 5*x4 - (x1 + 5)], x) + >>> aeq, beq = M([Eq(3*x2 + x4, 2), Eq(-x1 + x3 + 2*x4, 1)], x) + >>> constr = [i <= j for i,j in zip(a*X, b)] + >>> constr += [Eq(i, j) for i,j in zip(aeq*X, beq)] + >>> linprog(c, a, b, aeq, beq) + (9/2, [0, 1/2, 0, 1/2]) + >>> assert all(i.subs(dict(zip(x, _[1]))) for i in constr) + + See Also + ======== + lpmin, lpmax + """ + + ## the objective + C = Matrix(c) + if C.rows != 1 and C.cols == 1: + C = C.T + if C.rows != 1: + raise ValueError("C must be a single row.") + + ## the inequalities + if not A: + if b: + raise ValueError("A and b must both be given") + # the governing equations will be simple constraints + # on variables + A, b = zeros(0, C.cols), zeros(C.cols, 1) + else: + A, b = [Matrix(i) for i in (A, b)] + + if A.cols != C.cols: + raise ValueError("number of columns in A and C must match") + + ## the equalities + if A_eq is None: + if not b_eq is None: + raise ValueError("A_eq and b_eq must both be given") + else: + A_eq, b_eq = [Matrix(i) for i in (A_eq, b_eq)] + # if x == y then x <= y and x >= y (-x <= -y) + A = A.col_join(A_eq) + A = A.col_join(-A_eq) + b = b.col_join(b_eq) + b = b.col_join(-b_eq) + + if not (bounds is None or bounds == {} or bounds == (0, None)): + ## the bounds are interpreted + if type(bounds) is tuple and len(bounds) == 2: + bounds = [bounds] * A.cols + elif len(bounds) == A.cols and all( + type(i) is tuple and len(i) == 2 for i in bounds): + pass # individual bounds + elif type(bounds) is dict and all( + type(i) is tuple and len(i) == 2 + for i in bounds.values()): + # sparse bounds + db = bounds + bounds = [(0, None)] * A.cols + while db: + i, j = db.popitem() + bounds[i] = j # IndexError if out-of-bounds indices + else: + raise ValueError("unexpected bounds %s" % bounds) + A_, b_ = _handle_bounds(bounds) + aux = A_.cols - A.cols + if A: + A = Matrix([[A, zeros(A.rows, aux)], [A_]]) + b = b.col_join(b_) + else: + A = A_ + b = b_ + C = C.row_join(zeros(1, aux)) + else: + aux = -A.cols # set so -aux will give all cols below + + o, p, d = _simplex(A, b, C) + return o, p[:-aux] # don't include aux values + +def show_linprog(c, A=None, b=None, A_eq=None, b_eq=None, bounds=None): + from sympy import symbols + ## the objective + C = Matrix(c) + if C.rows != 1 and C.cols == 1: + C = C.T + if C.rows != 1: + raise ValueError("C must be a single row.") + + ## the inequalities + if not A: + if b: + raise ValueError("A and b must both be given") + # the governing equations will be simple constraints + # on variables + A, b = zeros(0, C.cols), zeros(C.cols, 1) + else: + A, b = [Matrix(i) for i in (A, b)] + + if A.cols != C.cols: + raise ValueError("number of columns in A and C must match") + + ## the equalities + if A_eq is None: + if not b_eq is None: + raise ValueError("A_eq and b_eq must both be given") + else: + A_eq, b_eq = [Matrix(i) for i in (A_eq, b_eq)] + + if not (bounds is None or bounds == {} or bounds == (0, None)): + ## the bounds are interpreted + if type(bounds) is tuple and len(bounds) == 2: + bounds = [bounds] * A.cols + elif len(bounds) == A.cols and all( + type(i) is tuple and len(i) == 2 for i in bounds): + pass # individual bounds + elif type(bounds) is dict and all( + type(i) is tuple and len(i) == 2 + for i in bounds.values()): + # sparse bounds + db = bounds + bounds = [(0, None)] * A.cols + while db: + i, j = db.popitem() + bounds[i] = j # IndexError if out-of-bounds indices + else: + raise ValueError("unexpected bounds %s" % bounds) + + x = Matrix(symbols('x1:%s' % (A.cols+1))) + f,c = (C*x)[0], [i<=j for i,j in zip(A*x, b)] + [Eq(i,j) for i,j in zip(A_eq*x,b_eq)] + for i, (lo, hi) in enumerate(bounds): + if lo is not None: + c.append(x[i]>=lo) + if hi is not None: + c.append(x[i]<=hi) + return f,c diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/solvers.py b/.venv/lib/python3.13/site-packages/sympy/solvers/solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..ef621a84e34bde43a3181d2fd90e26fa7b05e968 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/solvers.py @@ -0,0 +1,3674 @@ +""" +This module contain solvers for all kinds of equations: + + - algebraic or transcendental, use solve() + + - recurrence, use rsolve() + + - differential, use dsolve() + + - nonlinear (numerically), use nsolve() + (you will need a good starting point) + +""" +from __future__ import annotations + +from sympy.core import (S, Add, Symbol, Dummy, Expr, Mul) +from sympy.core.assumptions import check_assumptions +from sympy.core.exprtools import factor_terms +from sympy.core.function import (expand_mul, expand_log, Derivative, + AppliedUndef, UndefinedFunction, nfloat, + Function, expand_power_exp, _mexpand, expand, + expand_func) +from sympy.core.logic import fuzzy_not, fuzzy_and +from sympy.core.numbers import Float, Rational, _illegal +from sympy.core.intfunc import integer_log, ilcm +from sympy.core.power import Pow +from sympy.core.relational import Eq, Ne +from sympy.core.sorting import ordered, default_sort_key +from sympy.core.sympify import sympify, _sympify +from sympy.core.traversal import preorder_traversal +from sympy.logic.boolalg import And, BooleanAtom + +from sympy.functions import (log, exp, LambertW, cos, sin, tan, acos, asin, atan, + Abs, re, im, arg, sqrt, atan2) +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.elementary.hyperbolic import HyperbolicFunction +from sympy.functions.elementary.piecewise import piecewise_fold, Piecewise +from sympy.functions.elementary.trigonometric import TrigonometricFunction +from sympy.integrals.integrals import Integral +from sympy.ntheory.factor_ import divisors +from sympy.simplify import (simplify, collect, powsimp, posify, # type: ignore + powdenest, nsimplify, denom, logcombine, sqrtdenest, fraction, + separatevars) +from sympy.simplify.sqrtdenest import sqrt_depth +from sympy.simplify.fu import TR1, TR2i, TR10, TR11 +from sympy.strategies.rl import rebuild +from sympy.matrices.exceptions import NonInvertibleMatrixError +from sympy.matrices import Matrix, zeros +from sympy.polys import roots, cancel, factor, Poly +from sympy.polys.solvers import sympy_eqs_to_ring, solve_lin_sys +from sympy.polys.polyerrors import GeneratorsNeeded, PolynomialError +from sympy.polys.polytools import gcd +from sympy.utilities.lambdify import lambdify +from sympy.utilities.misc import filldedent, debugf +from sympy.utilities.iterables import (connected_components, + generate_bell, uniq, iterable, is_sequence, subsets, flatten, sift) +from sympy.utilities.decorator import conserve_mpmath_dps + +from mpmath import findroot + +from sympy.solvers.polysys import solve_poly_system + +from types import GeneratorType +from collections import defaultdict +from itertools import combinations, product + +import warnings + + +def recast_to_symbols(eqs, symbols): + """ + Return (e, s, d) where e and s are versions of *eqs* and + *symbols* in which any non-Symbol objects in *symbols* have + been replaced with generic Dummy symbols and d is a dictionary + that can be used to restore the original expressions. + + Examples + ======== + + >>> from sympy.solvers.solvers import recast_to_symbols + >>> from sympy import symbols, Function + >>> x, y = symbols('x y') + >>> fx = Function('f')(x) + >>> eqs, syms = [fx + 1, x, y], [fx, y] + >>> e, s, d = recast_to_symbols(eqs, syms); (e, s, d) + ([_X0 + 1, x, y], [_X0, y], {_X0: f(x)}) + + The original equations and symbols can be restored using d: + + >>> assert [i.xreplace(d) for i in eqs] == eqs + >>> assert [d.get(i, i) for i in s] == syms + + """ + if not iterable(eqs) and iterable(symbols): + raise ValueError('Both eqs and symbols must be iterable') + orig = list(symbols) + symbols = list(ordered(symbols)) + swap_sym = {} + i = 0 + for s in symbols: + if not isinstance(s, Symbol) and s not in swap_sym: + swap_sym[s] = Dummy('X%d' % i) + i += 1 + new_f = [] + for i in eqs: + isubs = getattr(i, 'subs', None) + if isubs is not None: + new_f.append(isubs(swap_sym)) + else: + new_f.append(i) + restore = {v: k for k, v in swap_sym.items()} + return new_f, [swap_sym.get(i, i) for i in orig], restore + + +def _ispow(e): + """Return True if e is a Pow or is exp.""" + return isinstance(e, Expr) and (e.is_Pow or isinstance(e, exp)) + + +def _simple_dens(f, symbols): + # when checking if a denominator is zero, we can just check the + # base of powers with nonzero exponents since if the base is zero + # the power will be zero, too. To keep it simple and fast, we + # limit simplification to exponents that are Numbers + dens = set() + for d in denoms(f, symbols): + if d.is_Pow and d.exp.is_Number: + if d.exp.is_zero: + continue # foo**0 is never 0 + d = d.base + dens.add(d) + return dens + + +def denoms(eq, *symbols): + """ + Return (recursively) set of all denominators that appear in *eq* + that contain any symbol in *symbols*; if *symbols* are not + provided then all denominators will be returned. + + Examples + ======== + + >>> from sympy.solvers.solvers import denoms + >>> from sympy.abc import x, y, z + + >>> denoms(x/y) + {y} + + >>> denoms(x/(y*z)) + {y, z} + + >>> denoms(3/x + y/z) + {x, z} + + >>> denoms(x/2 + y/z) + {2, z} + + If *symbols* are provided then only denominators containing + those symbols will be returned: + + >>> denoms(1/x + 1/y + 1/z, y, z) + {y, z} + + """ + + pot = preorder_traversal(eq) + dens = set() + for p in pot: + # Here p might be Tuple or Relational + # Expr subtrees (e.g. lhs and rhs) will be traversed after by pot + if not isinstance(p, Expr): + continue + den = denom(p) + if den is S.One: + continue + dens.update(Mul.make_args(den)) + if not symbols: + return dens + elif len(symbols) == 1: + if iterable(symbols[0]): + symbols = symbols[0] + return {d for d in dens if any(s in d.free_symbols for s in symbols)} + + +def checksol(f, symbol, sol=None, **flags): + """ + Checks whether sol is a solution of equation f == 0. + + Explanation + =========== + + Input can be either a single symbol and corresponding value + or a dictionary of symbols and values. When given as a dictionary + and flag ``simplify=True``, the values in the dictionary will be + simplified. *f* can be a single equation or an iterable of equations. + A solution must satisfy all equations in *f* to be considered valid; + if a solution does not satisfy any equation, False is returned; if one or + more checks are inconclusive (and none are False) then None is returned. + + Examples + ======== + + >>> from sympy import checksol, symbols + >>> x, y = symbols('x,y') + >>> checksol(x**4 - 1, x, 1) + True + >>> checksol(x**4 - 1, x, 0) + False + >>> checksol(x**2 + y**2 - 5**2, {x: 3, y: 4}) + True + + To check if an expression is zero using ``checksol()``, pass it + as *f* and send an empty dictionary for *symbol*: + + >>> checksol(x**2 + x - x*(x + 1), {}) + True + + None is returned if ``checksol()`` could not conclude. + + flags: + 'numerical=True (default)' + do a fast numerical check if ``f`` has only one symbol. + 'minimal=True (default is False)' + a very fast, minimal testing. + 'warn=True (default is False)' + show a warning if checksol() could not conclude. + 'simplify=True (default)' + simplify solution before substituting into function and + simplify the function before trying specific simplifications + 'force=True (default is False)' + make positive all symbols without assumptions regarding sign. + + """ + from sympy.physics.units import Unit + + minimal = flags.get('minimal', False) + + if sol is not None: + sol = {symbol: sol} + elif isinstance(symbol, dict): + sol = symbol + else: + msg = 'Expecting (sym, val) or ({sym: val}, None) but got (%s, %s)' + raise ValueError(msg % (symbol, sol)) + + if iterable(f): + if not f: + raise ValueError('no functions to check') + return fuzzy_and(checksol(fi, sol, **flags) for fi in f) + + f = _sympify(f) + + if f.is_number: + return f.is_zero + + if isinstance(f, Poly): + f = f.as_expr() + elif isinstance(f, (Eq, Ne)): + if f.rhs in (S.true, S.false): + f = f.reversed + B, E = f.args + if isinstance(B, BooleanAtom): + f = f.subs(sol) + if not f.is_Boolean: + return + elif isinstance(f, Eq): + f = Add(f.lhs, -f.rhs, evaluate=False) + + if isinstance(f, BooleanAtom): + return bool(f) + elif not f.is_Relational and not f: + return True + + illegal = set(_illegal) + if any(sympify(v).atoms() & illegal for k, v in sol.items()): + return False + + attempt = -1 + numerical = flags.get('numerical', True) + while 1: + attempt += 1 + if attempt == 0: + val = f.subs(sol) + if isinstance(val, Mul): + val = val.as_independent(Unit)[0] + if val.atoms() & illegal: + return False + elif attempt == 1: + if not val.is_number: + if not val.is_constant(*list(sol.keys()), simplify=not minimal): + return False + # there are free symbols -- simple expansion might work + _, val = val.as_content_primitive() + val = _mexpand(val.as_numer_denom()[0], recursive=True) + elif attempt == 2: + if minimal: + return + if flags.get('simplify', True): + for k in sol: + sol[k] = simplify(sol[k]) + # start over without the failed expanded form, possibly + # with a simplified solution + val = simplify(f.subs(sol)) + if flags.get('force', True): + val, reps = posify(val) + # expansion may work now, so try again and check + exval = _mexpand(val, recursive=True) + if exval.is_number: + # we can decide now + val = exval + else: + # if there are no radicals and no functions then this can't be + # zero anymore -- can it? + pot = preorder_traversal(expand_mul(val)) + seen = set() + saw_pow_func = False + for p in pot: + if p in seen: + continue + seen.add(p) + if p.is_Pow and not p.exp.is_Integer: + saw_pow_func = True + elif p.is_Function: + saw_pow_func = True + elif isinstance(p, UndefinedFunction): + saw_pow_func = True + if saw_pow_func: + break + if saw_pow_func is False: + return False + if flags.get('force', True): + # don't do a zero check with the positive assumptions in place + val = val.subs(reps) + nz = fuzzy_not(val.is_zero) + if nz is not None: + # issue 5673: nz may be True even when False + # so these are just hacks to keep a false positive + # from being returned + + # HACK 1: LambertW (issue 5673) + if val.is_number and val.has(LambertW): + # don't eval this to verify solution since if we got here, + # numerical must be False + return None + + # add other HACKs here if necessary, otherwise we assume + # the nz value is correct + return not nz + break + if val.is_Rational: + return val == 0 + if numerical and val.is_number: + return (abs(val.n(18).n(12, chop=True)) < 1e-9) is S.true + + if flags.get('warn', False): + warnings.warn("\n\tWarning: could not verify solution %s." % sol) + # returns None if it can't conclude + # TODO: improve solution testing + + +def solve(f, *symbols, **flags): + r""" + Algebraically solves equations and systems of equations. + + Explanation + =========== + + Currently supported: + - polynomial + - transcendental + - piecewise combinations of the above + - systems of linear and polynomial equations + - systems containing relational expressions + - systems implied by undetermined coefficients + + Examples + ======== + + The default output varies according to the input and might + be a list (possibly empty), a dictionary, a list of + dictionaries or tuples, or an expression involving relationals. + For specifics regarding different forms of output that may appear, see :ref:`solve_output`. + Let it suffice here to say that to obtain a uniform output from + `solve` use ``dict=True`` or ``set=True`` (see below). + + >>> from sympy import solve, Poly, Eq, Matrix, Symbol + >>> from sympy.abc import x, y, z, a, b + + The expressions that are passed can be Expr, Equality, or Poly + classes (or lists of the same); a Matrix is considered to be a + list of all the elements of the matrix: + + >>> solve(x - 3, x) + [3] + >>> solve(Eq(x, 3), x) + [3] + >>> solve(Poly(x - 3), x) + [3] + >>> solve(Matrix([[x, x + y]]), x, y) == solve([x, x + y], x, y) + True + + If no symbols are indicated to be of interest and the equation is + univariate, a list of values is returned; otherwise, the keys in + a dictionary will indicate which (of all the variables used in + the expression(s)) variables and solutions were found: + + >>> solve(x**2 - 4) + [-2, 2] + >>> solve((x - a)*(y - b)) + [{a: x}, {b: y}] + >>> solve([x - 3, y - 1]) + {x: 3, y: 1} + >>> solve([x - 3, y**2 - 1]) + [{x: 3, y: -1}, {x: 3, y: 1}] + + If you pass symbols for which solutions are sought, the output will vary + depending on the number of symbols you passed, whether you are passing + a list of expressions or not, and whether a linear system was solved. + Uniform output is attained by using ``dict=True`` or ``set=True``. + + >>> #### *** feel free to skip to the stars below *** #### + >>> from sympy import TableForm + >>> h = [None, ';|;'.join(['e', 's', 'solve(e, s)', 'solve(e, s, dict=True)', + ... 'solve(e, s, set=True)']).split(';')] + >>> t = [] + >>> for e, s in [ + ... (x - y, y), + ... (x - y, [x, y]), + ... (x**2 - y, [x, y]), + ... ([x - 3, y -1], [x, y]), + ... ]: + ... how = [{}, dict(dict=True), dict(set=True)] + ... res = [solve(e, s, **f) for f in how] + ... t.append([e, '|', s, '|'] + [res[0], '|', res[1], '|', res[2]]) + ... + >>> # ******************************************************* # + >>> TableForm(t, headings=h, alignments="<") + e | s | solve(e, s) | solve(e, s, dict=True) | solve(e, s, set=True) + --------------------------------------------------------------------------------------- + x - y | y | [x] | [{y: x}] | ([y], {(x,)}) + x - y | [x, y] | [(y, y)] | [{x: y}] | ([x, y], {(y, y)}) + x**2 - y | [x, y] | [(x, x**2)] | [{y: x**2}] | ([x, y], {(x, x**2)}) + [x - 3, y - 1] | [x, y] | {x: 3, y: 1} | [{x: 3, y: 1}] | ([x, y], {(3, 1)}) + + * If any equation does not depend on the symbol(s) given, it will be + eliminated from the equation set and an answer may be given + implicitly in terms of variables that were not of interest: + + >>> solve([x - y, y - 3], x) + {x: y} + + When you pass all but one of the free symbols, an attempt + is made to find a single solution based on the method of + undetermined coefficients. If it succeeds, a dictionary of values + is returned. If you want an algebraic solutions for one + or more of the symbols, pass the expression to be solved in a list: + + >>> e = a*x + b - 2*x - 3 + >>> solve(e, [a, b]) + {a: 2, b: 3} + >>> solve([e], [a, b]) + {a: -b/x + (2*x + 3)/x} + + When there is no solution for any given symbol which will make all + expressions zero, the empty list is returned (or an empty set in + the tuple when ``set=True``): + + >>> from sympy import sqrt + >>> solve(3, x) + [] + >>> solve(x - 3, y) + [] + >>> solve(sqrt(x) + 1, x, set=True) + ([x], set()) + + When an object other than a Symbol is given as a symbol, it is + isolated algebraically and an implicit solution may be obtained. + This is mostly provided as a convenience to save you from replacing + the object with a Symbol and solving for that Symbol. It will only + work if the specified object can be replaced with a Symbol using the + subs method: + + >>> from sympy import exp, Function + >>> f = Function('f') + + >>> solve(f(x) - x, f(x)) + [x] + >>> solve(f(x).diff(x) - f(x) - x, f(x).diff(x)) + [x + f(x)] + >>> solve(f(x).diff(x) - f(x) - x, f(x)) + [-x + Derivative(f(x), x)] + >>> solve(x + exp(x)**2, exp(x), set=True) + ([exp(x)], {(-sqrt(-x),), (sqrt(-x),)}) + + >>> from sympy import Indexed, IndexedBase, Tuple + >>> A = IndexedBase('A') + >>> eqs = Tuple(A[1] + A[2] - 3, A[1] - A[2] + 1) + >>> solve(eqs, eqs.atoms(Indexed)) + {A[1]: 1, A[2]: 2} + + * To solve for a function within a derivative, use :func:`~.dsolve`. + + To solve for a symbol implicitly, use implicit=True: + + >>> solve(x + exp(x), x) + [-LambertW(1)] + >>> solve(x + exp(x), x, implicit=True) + [-exp(x)] + + It is possible to solve for anything in an expression that can be + replaced with a symbol using :obj:`~sympy.core.basic.Basic.subs`: + + >>> solve(x + 2 + sqrt(3), x + 2) + [-sqrt(3)] + >>> solve((x + 2 + sqrt(3), x + 4 + y), y, x + 2) + {y: -2 + sqrt(3), x + 2: -sqrt(3)} + + * Nothing heroic is done in this implicit solving so you may end up + with a symbol still in the solution: + + >>> eqs = (x*y + 3*y + sqrt(3), x + 4 + y) + >>> solve(eqs, y, x + 2) + {y: -sqrt(3)/(x + 3), x + 2: -2*x/(x + 3) - 6/(x + 3) + sqrt(3)/(x + 3)} + >>> solve(eqs, y*x, x) + {x: -y - 4, x*y: -3*y - sqrt(3)} + + * If you attempt to solve for a number, remember that the number + you have obtained does not necessarily mean that the value is + equivalent to the expression obtained: + + >>> solve(sqrt(2) - 1, 1) + [sqrt(2)] + >>> solve(x - y + 1, 1) # /!\ -1 is targeted, too + [x/(y - 1)] + >>> [_.subs(z, -1) for _ in solve((x - y + 1).subs(-1, z), 1)] + [-x + y] + + **Additional Examples** + + ``solve()`` with check=True (default) will run through the symbol tags to + eliminate unwanted solutions. If no assumptions are included, all possible + solutions will be returned: + + >>> x = Symbol("x") + >>> solve(x**2 - 1) + [-1, 1] + + By setting the ``positive`` flag, only one solution will be returned: + + >>> pos = Symbol("pos", positive=True) + >>> solve(pos**2 - 1) + [1] + + When the solutions are checked, those that make any denominator zero + are automatically excluded. If you do not want to exclude such solutions, + then use the check=False option: + + >>> from sympy import sin, limit + >>> solve(sin(x)/x) # 0 is excluded + [pi] + + If ``check=False``, then a solution to the numerator being zero is found + but the value of $x = 0$ is a spurious solution since $\sin(x)/x$ has the well + known limit (without discontinuity) of 1 at $x = 0$: + + >>> solve(sin(x)/x, check=False) + [0, pi] + + In the following case, however, the limit exists and is equal to the + value of $x = 0$ that is excluded when check=True: + + >>> eq = x**2*(1/x - z**2/x) + >>> solve(eq, x) + [] + >>> solve(eq, x, check=False) + [0] + >>> limit(eq, x, 0, '-') + 0 + >>> limit(eq, x, 0, '+') + 0 + + **Solving Relationships** + + When one or more expressions passed to ``solve`` is a relational, + a relational result is returned (and the ``dict`` and ``set`` flags + are ignored): + + >>> solve(x < 3) + (-oo < x) & (x < 3) + >>> solve([x < 3, x**2 > 4], x) + ((-oo < x) & (x < -2)) | ((2 < x) & (x < 3)) + >>> solve([x + y - 3, x > 3], x) + (3 < x) & (x < oo) & Eq(x, 3 - y) + + Although checking of assumptions on symbols in relationals + is not done, setting assumptions will affect how certain + relationals might automatically simplify: + + >>> solve(x**2 > 4) + ((-oo < x) & (x < -2)) | ((2 < x) & (x < oo)) + + >>> r = Symbol('r', real=True) + >>> solve(r**2 > 4) + (2 < r) | (r < -2) + + There is currently no algorithm in SymPy that allows you to use + relationships to resolve more than one variable. So the following + does not determine that ``q < 0`` (and trying to solve for ``r`` + and ``q`` will raise an error): + + >>> from sympy import symbols + >>> r, q = symbols('r, q', real=True) + >>> solve([r + q - 3, r > 3], r) + (3 < r) & Eq(r, 3 - q) + + You can directly call the routine that ``solve`` calls + when it encounters a relational: :func:`~.reduce_inequalities`. + It treats Expr like Equality. + + >>> from sympy import reduce_inequalities + >>> reduce_inequalities([x**2 - 4]) + Eq(x, -2) | Eq(x, 2) + + If each relationship contains only one symbol of interest, + the expressions can be processed for multiple symbols: + + >>> reduce_inequalities([0 <= x - 1, y < 3], [x, y]) + (-oo < y) & (1 <= x) & (x < oo) & (y < 3) + + But an error is raised if any relationship has more than one + symbol of interest: + + >>> reduce_inequalities([0 <= x*y - 1, y < 3], [x, y]) + Traceback (most recent call last): + ... + NotImplementedError: + inequality has more than one symbol of interest. + + **Disabling High-Order Explicit Solutions** + + When solving polynomial expressions, you might not want explicit solutions + (which can be quite long). If the expression is univariate, ``CRootOf`` + instances will be returned instead: + + >>> solve(x**3 - x + 1) + [-1/((-1/2 - sqrt(3)*I/2)*(3*sqrt(69)/2 + 27/2)**(1/3)) - + (-1/2 - sqrt(3)*I/2)*(3*sqrt(69)/2 + 27/2)**(1/3)/3, + -(-1/2 + sqrt(3)*I/2)*(3*sqrt(69)/2 + 27/2)**(1/3)/3 - + 1/((-1/2 + sqrt(3)*I/2)*(3*sqrt(69)/2 + 27/2)**(1/3)), + -(3*sqrt(69)/2 + 27/2)**(1/3)/3 - + 1/(3*sqrt(69)/2 + 27/2)**(1/3)] + >>> solve(x**3 - x + 1, cubics=False) + [CRootOf(x**3 - x + 1, 0), + CRootOf(x**3 - x + 1, 1), + CRootOf(x**3 - x + 1, 2)] + + If the expression is multivariate, no solution might be returned: + + >>> solve(x**3 - x + a, x, cubics=False) + [] + + Sometimes solutions will be obtained even when a flag is False because the + expression could be factored. In the following example, the equation can + be factored as the product of a linear and a quadratic factor so explicit + solutions (which did not require solving a cubic expression) are obtained: + + >>> eq = x**3 + 3*x**2 + x - 1 + >>> solve(eq, cubics=False) + [-1, -1 + sqrt(2), -sqrt(2) - 1] + + **Solving Equations Involving Radicals** + + Because of SymPy's use of the principle root, some solutions + to radical equations will be missed unless check=False: + + >>> from sympy import root + >>> eq = root(x**3 - 3*x**2, 3) + 1 - x + >>> solve(eq) + [] + >>> solve(eq, check=False) + [1/3] + + In the above example, there is only a single solution to the + equation. Other expressions will yield spurious roots which + must be checked manually; roots which give a negative argument + to odd-powered radicals will also need special checking: + + >>> from sympy import real_root, S + >>> eq = root(x, 3) - root(x, 5) + S(1)/7 + >>> solve(eq) # this gives 2 solutions but misses a 3rd + [CRootOf(7*x**5 - 7*x**3 + 1, 1)**15, + CRootOf(7*x**5 - 7*x**3 + 1, 2)**15] + >>> sol = solve(eq, check=False) + >>> [abs(eq.subs(x,i).n(2)) for i in sol] + [0.48, 0.e-110, 0.e-110, 0.052, 0.052] + + The first solution is negative so ``real_root`` must be used to see that it + satisfies the expression: + + >>> abs(real_root(eq.subs(x, sol[0])).n(2)) + 0.e-110 + + If the roots of the equation are not real then more care will be + necessary to find the roots, especially for higher order equations. + Consider the following expression: + + >>> expr = root(x, 3) - root(x, 5) + + We will construct a known value for this expression at x = 3 by selecting + the 1-th root for each radical: + + >>> expr1 = root(x, 3, 1) - root(x, 5, 1) + >>> v = expr1.subs(x, -3) + + The ``solve`` function is unable to find any exact roots to this equation: + + >>> eq = Eq(expr, v); eq1 = Eq(expr1, v) + >>> solve(eq, check=False), solve(eq1, check=False) + ([], []) + + The function ``unrad``, however, can be used to get a form of the equation + for which numerical roots can be found: + + >>> from sympy.solvers.solvers import unrad + >>> from sympy import nroots + >>> e, (p, cov) = unrad(eq) + >>> pvals = nroots(e) + >>> inversion = solve(cov, x)[0] + >>> xvals = [inversion.subs(p, i) for i in pvals] + + Although ``eq`` or ``eq1`` could have been used to find ``xvals``, the + solution can only be verified with ``expr1``: + + >>> z = expr - v + >>> [xi.n(chop=1e-9) for xi in xvals if abs(z.subs(x, xi).n()) < 1e-9] + [] + >>> z1 = expr1 - v + >>> [xi.n(chop=1e-9) for xi in xvals if abs(z1.subs(x, xi).n()) < 1e-9] + [-3.0] + + Parameters + ========== + + f : + - a single Expr or Poly that must be zero + - an Equality + - a Relational expression + - a Boolean + - iterable of one or more of the above + + symbols : (object(s) to solve for) specified as + - none given (other non-numeric objects will be used) + - single symbol + - denested list of symbols + (e.g., ``solve(f, x, y)``) + - ordered iterable of symbols + (e.g., ``solve(f, [x, y])``) + + flags : + dict=True (default is False) + Return list (perhaps empty) of solution mappings. + set=True (default is False) + Return list of symbols and set of tuple(s) of solution(s). + exclude=[] (default) + Do not try to solve for any of the free symbols in exclude; + if expressions are given, the free symbols in them will + be extracted automatically. + check=True (default) + If False, do not do any testing of solutions. This can be + useful if you want to include solutions that make any + denominator zero. + numerical=True (default) + Do a fast numerical check if *f* has only one symbol. + minimal=True (default is False) + A very fast, minimal testing. + warn=True (default is False) + Show a warning if ``checksol()`` could not conclude. + simplify=True (default) + Simplify all but polynomials of order 3 or greater before + returning them and (if check is not False) use the + general simplify function on the solutions and the + expression obtained when they are substituted into the + function which should be zero. + force=True (default is False) + Make positive all symbols without assumptions regarding sign. + rational=True (default) + Recast Floats as Rational; if this option is not used, the + system containing Floats may fail to solve because of issues + with polys. If rational=None, Floats will be recast as + rationals but the answer will be recast as Floats. If the + flag is False then nothing will be done to the Floats. + manual=True (default is False) + Do not use the polys/matrix method to solve a system of + equations, solve them one at a time as you might "manually." + implicit=True (default is False) + Allows ``solve`` to return a solution for a pattern in terms of + other functions that contain that pattern; this is only + needed if the pattern is inside of some invertible function + like cos, exp, etc. + particular=True (default is False) + Instructs ``solve`` to try to find a particular solution to + a linear system with as many zeros as possible; this is very + expensive. + quick=True (default is False; ``particular`` must be True) + Selects a fast heuristic to find a solution with many zeros + whereas a value of False uses the very slow method guaranteed + to find the largest number of zeros possible. + cubics=True (default) + Return explicit solutions when cubic expressions are encountered. + When False, quartics and quintics are disabled, too. + quartics=True (default) + Return explicit solutions when quartic expressions are encountered. + When False, quintics are disabled, too. + quintics=True (default) + Return explicit solutions (if possible) when quintic expressions + are encountered. + + See Also + ======== + + rsolve: For solving recurrence relationships + sympy.solvers.ode.dsolve: For solving differential equations + + """ + from .inequalities import reduce_inequalities + + # checking/recording flags + ########################################################################### + + # set solver types explicitly; as soon as one is False + # all the rest will be False + hints = ('cubics', 'quartics', 'quintics') + default = True + for k in hints: + default = flags.setdefault(k, bool(flags.get(k, default))) + + # allow solution to contain symbol if True: + implicit = flags.get('implicit', False) + + # record desire to see warnings + warn = flags.get('warn', False) + + # this flag will be needed for quick exits below, so record + # now -- but don't record `dict` yet since it might change + as_set = flags.get('set', False) + + # keeping track of how f was passed + bare_f = not iterable(f) + + # check flag usage for particular/quick which should only be used + # with systems of equations + if flags.get('quick', None) is not None: + if not flags.get('particular', None): + raise ValueError('when using `quick`, `particular` should be True') + if flags.get('particular', False) and bare_f: + raise ValueError(filldedent(""" + The 'particular/quick' flag is usually used with systems of + equations. Either pass your equation in a list or + consider using a solver like `diophantine` if you are + looking for a solution in integers.""")) + + # sympify everything, creating list of expressions and list of symbols + ########################################################################### + + def _sympified_list(w): + return list(map(sympify, w if iterable(w) else [w])) + f, symbols = (_sympified_list(w) for w in [f, symbols]) + + # preprocess symbol(s) + ########################################################################### + + ordered_symbols = None # were the symbols in a well defined order? + if not symbols: + # get symbols from equations + symbols = set().union(*[fi.free_symbols for fi in f]) + if len(symbols) < len(f): + for fi in f: + pot = preorder_traversal(fi) + for p in pot: + if isinstance(p, AppliedUndef): + if not as_set: + flags['dict'] = True # better show symbols + symbols.add(p) + pot.skip() # don't go any deeper + ordered_symbols = False + symbols = list(ordered(symbols)) # to make it canonical + else: + if len(symbols) == 1 and iterable(symbols[0]): + symbols = symbols[0] + ordered_symbols = symbols and is_sequence(symbols, + include=GeneratorType) + _symbols = list(uniq(symbols)) + if len(_symbols) != len(symbols): + ordered_symbols = False + symbols = list(ordered(symbols)) + else: + symbols = _symbols + + # check for duplicates + if len(symbols) != len(set(symbols)): + raise ValueError('duplicate symbols given') + # remove those not of interest + exclude = flags.pop('exclude', set()) + if exclude: + if isinstance(exclude, Expr): + exclude = [exclude] + exclude = set().union(*[e.free_symbols for e in sympify(exclude)]) + symbols = [s for s in symbols if s not in exclude] + + # preprocess equation(s) + ########################################################################### + + # automatically ignore True values + if isinstance(f, list): + f = [s for s in f if s is not S.true] + + # handle canonicalization of equation types + for i, fi in enumerate(f): + if isinstance(fi, (Eq, Ne)): + if 'ImmutableDenseMatrix' in [type(a).__name__ for a in fi.args]: + fi = fi.lhs - fi.rhs + else: + L, R = fi.args + if isinstance(R, BooleanAtom): + L, R = R, L + if isinstance(L, BooleanAtom): + if isinstance(fi, Ne): + L = ~L + if R.is_Relational: + fi = ~R if L is S.false else R + elif R.is_Symbol: + return L + elif R.is_Boolean and (~R).is_Symbol: + return ~L + else: + raise NotImplementedError(filldedent(''' + Unanticipated argument of Eq when other arg + is True or False. + ''')) + elif isinstance(fi, Eq): + fi = Add(fi.lhs, -fi.rhs, evaluate=False) + f[i] = fi + + # *** dispatch and handle as a system of relationals + # ************************************************** + if fi.is_Relational: + if len(symbols) != 1: + raise ValueError("can only solve for one symbol at a time") + if warn and symbols[0].assumptions0: + warnings.warn(filldedent(""" + \tWarning: assumptions about variable '%s' are + not handled currently.""" % symbols[0])) + return reduce_inequalities(f, symbols=symbols) + + # convert Poly to expression + if isinstance(fi, Poly): + f[i] = fi.as_expr() + + # rewrite hyperbolics in terms of exp if they have symbols of + # interest + f[i] = f[i].replace(lambda w: isinstance(w, HyperbolicFunction) and \ + w.has_free(*symbols), lambda w: w.rewrite(exp)) + + # if we have a Matrix, we need to iterate over its elements again + if f[i].is_Matrix: + try: + f[i] = f[i].as_explicit() + except ValueError: + raise ValueError( + "solve cannot handle matrices with symbolic shape." + ) + bare_f = False + f.extend(list(f[i])) + f[i] = S.Zero + + # if we can split it into real and imaginary parts then do so + freei = f[i].free_symbols + if freei and all(s.is_extended_real or s.is_imaginary for s in freei): + fr, fi = f[i].as_real_imag() + # accept as long as new re, im, arg or atan2 are not introduced + had = f[i].atoms(re, im, arg, atan2) + if fr and fi and fr != fi and not any( + i.atoms(re, im, arg, atan2) - had for i in (fr, fi)): + if bare_f: + bare_f = False + f[i: i + 1] = [fr, fi] + + # real/imag handling ----------------------------- + if any(isinstance(fi, (bool, BooleanAtom)) for fi in f): + if as_set: + return [], set() + return [] + + for i, fi in enumerate(f): + # Abs + while True: + was = fi + fi = fi.replace(Abs, lambda arg: + separatevars(Abs(arg)).rewrite(Piecewise) if arg.has(*symbols) + else Abs(arg)) + if was == fi: + break + + for e in fi.find(Abs): + if e.has(*symbols): + raise NotImplementedError('solving %s when the argument ' + 'is not real or imaginary.' % e) + + # arg + fi = fi.replace(arg, lambda a: arg(a).rewrite(atan2).rewrite(atan)) + + # save changes + f[i] = fi + + # see if re(s) or im(s) appear + freim = [fi for fi in f if fi.has(re, im)] + if freim: + irf = [] + for s in symbols: + if s.is_real or s.is_imaginary: + continue # neither re(x) nor im(x) will appear + # if re(s) or im(s) appear, the auxiliary equation must be present + if any(fi.has(re(s), im(s)) for fi in freim): + irf.append((s, re(s) + S.ImaginaryUnit*im(s))) + if irf: + for s, rhs in irf: + f = [fi.xreplace({s: rhs}) for fi in f] + [s - rhs] + symbols.extend([re(s), im(s)]) + if bare_f: + bare_f = False + flags['dict'] = True + # end of real/imag handling ----------------------------- + + # we can solve for non-symbol entities by replacing them with Dummy symbols + f, symbols, swap_sym = recast_to_symbols(f, symbols) + # this set of symbols (perhaps recast) is needed below + symset = set(symbols) + + # get rid of equations that have no symbols of interest; we don't + # try to solve them because the user didn't ask and they might be + # hard to solve; this means that solutions may be given in terms + # of the eliminated equations e.g. solve((x-y, y-3), x) -> {x: y} + newf = [] + for fi in f: + # let the solver handle equations that.. + # - have no symbols but are expressions + # - have symbols of interest + # - have no symbols of interest but are constant + # but when an expression is not constant and has no symbols of + # interest, it can't change what we obtain for a solution from + # the remaining equations so we don't include it; and if it's + # zero it can be removed and if it's not zero, there is no + # solution for the equation set as a whole + # + # The reason for doing this filtering is to allow an answer + # to be obtained to queries like solve((x - y, y), x); without + # this mod the return value is [] + ok = False + if fi.free_symbols & symset: + ok = True + else: + if fi.is_number: + if fi.is_Number: + if fi.is_zero: + continue + return [] + ok = True + else: + if fi.is_constant(): + ok = True + if ok: + newf.append(fi) + if not newf: + if as_set: + return symbols, set() + return [] + f = newf + del newf + + # mask off any Object that we aren't going to invert: Derivative, + # Integral, etc... so that solving for anything that they contain will + # give an implicit solution + seen = set() + non_inverts = set() + for fi in f: + pot = preorder_traversal(fi) + for p in pot: + if not isinstance(p, Expr) or isinstance(p, Piecewise): + pass + elif (isinstance(p, bool) or + not p.args or + p in symset or + p.is_Add or p.is_Mul or + p.is_Pow and not implicit or + p.is_Function and not implicit) and p.func not in (re, im): + continue + elif p not in seen: + seen.add(p) + if p.free_symbols & symset: + non_inverts.add(p) + else: + continue + pot.skip() + del seen + non_inverts = dict(list(zip(non_inverts, [Dummy() for _ in non_inverts]))) + f = [fi.subs(non_inverts) for fi in f] + + # Both xreplace and subs are needed below: xreplace to force substitution + # inside Derivative, subs to handle non-straightforward substitutions + non_inverts = [(v, k.xreplace(swap_sym).subs(swap_sym)) for k, v in non_inverts.items()] + + # rationalize Floats + floats = False + if flags.get('rational', True) is not False: + for i, fi in enumerate(f): + if fi.has(Float): + floats = True + f[i] = nsimplify(fi, rational=True) + + # capture any denominators before rewriting since + # they may disappear after the rewrite, e.g. issue 14779 + flags['_denominators'] = _simple_dens(f[0], symbols) + + # Any embedded piecewise functions need to be brought out to the + # top level so that the appropriate strategy gets selected. + # However, this is necessary only if one of the piecewise + # functions depends on one of the symbols we are solving for. + def _has_piecewise(e): + if e.is_Piecewise: + return e.has(*symbols) + return any(_has_piecewise(a) for a in e.args) + for i, fi in enumerate(f): + if _has_piecewise(fi): + f[i] = piecewise_fold(fi) + + # expand angles of sums; in general, expand_trig will allow + # more roots to be found but this is not a great solultion + # to not returning a parametric solution, otherwise + # many values can be returned that have a simple + # relationship between values + targs = {t for fi in f for t in fi.atoms(TrigonometricFunction)} + if len(targs) > 1: + add, other = sift(targs, lambda x: x.args[0].is_Add, binary=True) + add, other = [[i for i in l if i.has_free(*symbols)] for l in (add, other)] + trep = {} + for t in add: + a = t.args[0] + ind, dep = a.as_independent(*symbols) + if dep in symbols or -dep in symbols: + # don't let expansion expand wrt anything in ind + n = Dummy() if not ind.is_Number else ind + trep[t] = TR10(t.func(dep + n)).xreplace({n: ind}) + if other and len(other) <= 2: + base = gcd(*[i.args[0] for i in other]) if len(other) > 1 else other[0].args[0] + for i in other: + trep[i] = TR11(i, base) + f = [fi.xreplace(trep) for fi in f] + + # + # try to get a solution + ########################################################################### + if bare_f: + solution = None + if len(symbols) != 1: + solution = _solve_undetermined(f[0], symbols, flags) + if not solution: + solution = _solve(f[0], *symbols, **flags) + else: + linear, solution = _solve_system(f, symbols, **flags) + assert type(solution) is list + assert not solution or type(solution[0]) is dict, solution + # + # postprocessing + ########################################################################### + # capture as_dict flag now (as_set already captured) + as_dict = flags.get('dict', False) + + # define how solution will get unpacked + tuple_format = lambda s: [tuple([i.get(x, x) for x in symbols]) for i in s] + if as_dict or as_set: + unpack = None + elif bare_f: + if len(symbols) == 1: + unpack = lambda s: [i[symbols[0]] for i in s] + elif len(solution) == 1 and len(solution[0]) == len(symbols): + # undetermined linear coeffs solution + unpack = lambda s: s[0] + elif ordered_symbols: + unpack = tuple_format + else: + unpack = lambda s: s + else: + if solution: + if linear and len(solution) == 1: + # if you want the tuple solution for the linear + # case, use `set=True` + unpack = lambda s: s[0] + elif ordered_symbols: + unpack = tuple_format + else: + unpack = lambda s: s + else: + unpack = None + + # Restore masked-off objects + if non_inverts and type(solution) is list: + solution = [{k: v.subs(non_inverts) for k, v in s.items()} + for s in solution] + + # Restore original "symbols" if a dictionary is returned. + # This is not necessary for + # - the single univariate equation case + # since the symbol will have been removed from the solution; + # - the nonlinear poly_system since that only supports zero-dimensional + # systems and those results come back as a list + # + # ** unless there were Derivatives with the symbols, but those were handled + # above. + if swap_sym: + symbols = [swap_sym.get(k, k) for k in symbols] + for i, sol in enumerate(solution): + solution[i] = {swap_sym.get(k, k): v.subs(swap_sym) + for k, v in sol.items()} + + # Get assumptions about symbols, to filter solutions. + # Note that if assumptions about a solution can't be verified, it is still + # returned. + check = flags.get('check', True) + + # restore floats + if floats and solution and flags.get('rational', None) is None: + solution = nfloat(solution, exponent=False) + # nfloat might reveal more duplicates + solution = _remove_duplicate_solutions(solution) + + if check and solution: # assumption checking + warn = flags.get('warn', False) + got_None = [] # solutions for which one or more symbols gave None + no_False = [] # solutions for which no symbols gave False + for sol in solution: + v = fuzzy_and(check_assumptions(val, **symb.assumptions0) + for symb, val in sol.items()) + if v is False: + continue + no_False.append(sol) + if v is None: + got_None.append(sol) + + solution = no_False + if warn and got_None: + warnings.warn(filldedent(""" + \tWarning: assumptions concerning following solution(s) + cannot be checked:""" + '\n\t' + + ', '.join(str(s) for s in got_None))) + + # + # done + ########################################################################### + + if not solution: + if as_set: + return symbols, set() + return [] + + # make orderings canonical for list of dictionaries + if not as_set: # for set, no point in ordering + solution = [{k: s[k] for k in ordered(s)} for s in solution] + solution.sort(key=default_sort_key) + + if not (as_set or as_dict): + return unpack(solution) + + if as_dict: + return solution + + # set output: (symbols, {t1, t2, ...}) from list of dictionaries; + # include all symbols for those that like a verbose solution + # and to resolve any differences in dictionary keys. + # + # The set results can easily be used to make a verbose dict as + # k, v = solve(eqs, syms, set=True) + # sol = [dict(zip(k,i)) for i in v] + # + if ordered_symbols: + k = symbols # keep preferred order + else: + # just unify the symbols for which solutions were found + k = list(ordered(set(flatten(tuple(i.keys()) for i in solution)))) + return k, {tuple([s.get(ki, ki) for ki in k]) for s in solution} + + +def _solve_undetermined(g, symbols, flags): + """solve helper to return a list with one dict (solution) else None + + A direct call to solve_undetermined_coeffs is more flexible and + can return both multiple solutions and handle more than one independent + variable. Here, we have to be more cautious to keep from solving + something that does not look like an undetermined coeffs system -- + to minimize the surprise factor since singularities that cancel are not + prohibited in solve_undetermined_coeffs. + """ + if g.free_symbols - set(symbols): + sol = solve_undetermined_coeffs(g, symbols, **dict(flags, dict=True, set=None)) + if len(sol) == 1: + return sol + + +def _solve(f, *symbols, **flags): + """Return a checked solution for *f* in terms of one or more of the + symbols in the form of a list of dictionaries. + + If no method is implemented to solve the equation, a NotImplementedError + will be raised. In the case that conversion of an expression to a Poly + gives None a ValueError will be raised. + """ + + not_impl_msg = "No algorithms are implemented to solve equation %s" + + if len(symbols) != 1: + # look for solutions for desired symbols that are independent + # of symbols already solved for, e.g. if we solve for x = y + # then no symbol having x in its solution will be returned. + + # First solve for linear symbols (since that is easier and limits + # solution size) and then proceed with symbols appearing + # in a non-linear fashion. Ideally, if one is solving a single + # expression for several symbols, they would have to be + # appear in factors of an expression, but we do not here + # attempt factorization. XXX perhaps handling a Mul + # should come first in this routine whether there is + # one or several symbols. + nonlin_s = [] + got_s = set() + rhs_s = set() + result = [] + for s in symbols: + xi, v = solve_linear(f, symbols=[s]) + if xi == s: + # no need to check but we should simplify if desired + if flags.get('simplify', True): + v = simplify(v) + vfree = v.free_symbols + if vfree & got_s: + # was linear, but has redundant relationship + # e.g. x - y = 0 has y == x is redundant for x == y + # so ignore + continue + rhs_s |= vfree + got_s.add(xi) + result.append({xi: v}) + elif xi: # there might be a non-linear solution if xi is not 0 + nonlin_s.append(s) + if not nonlin_s: + return result + for s in nonlin_s: + try: + soln = _solve(f, s, **flags) + for sol in soln: + if sol[s].free_symbols & got_s: + # depends on previously solved symbols: ignore + continue + got_s.add(s) + result.append(sol) + except NotImplementedError: + continue + if got_s: + return result + else: + raise NotImplementedError(not_impl_msg % f) + + # solve f for a single variable + + symbol = symbols[0] + + # expand binomials only if it has the unknown symbol + f = f.replace(lambda e: isinstance(e, binomial) and e.has(symbol), + lambda e: expand_func(e)) + + # checking will be done unless it is turned off before making a + # recursive call; the variables `checkdens` and `check` are + # captured here (for reference below) in case flag value changes + flags['check'] = checkdens = check = flags.pop('check', True) + + # build up solutions if f is a Mul + if f.is_Mul: + result = set() + for m in f.args: + if m in {S.NegativeInfinity, S.ComplexInfinity, S.Infinity}: + result = set() + break + soln = _vsolve(m, symbol, **flags) + result.update(set(soln)) + result = [{symbol: v} for v in result] + if check: + # all solutions have been checked but now we must + # check that the solutions do not set denominators + # in any factor to zero + dens = flags.get('_denominators', _simple_dens(f, symbols)) + result = [s for s in result if + not any(checksol(den, s, **flags) for den in + dens)] + # set flags for quick exit at end; solutions for each + # factor were already checked and simplified + check = False + flags['simplify'] = False + + elif f.is_Piecewise: + result = set() + if any(e.is_zero for e, c in f.args): + f = f.simplify() # failure imminent w/o help + + cond = neg = True + for expr, cnd in f.args: + # the explicit condition for this expr is the current cond + # and none of the previous conditions + cond = And(neg, cnd) + neg = And(neg, ~cond) + + if expr.is_zero and cond.simplify() != False: + raise NotImplementedError(filldedent(''' + An expression is already zero when %s. + This means that in this *region* the solution + is zero but solve can only represent discrete, + not interval, solutions. If this is a spurious + interval it might be resolved with simplification + of the Piecewise conditions.''' % cond)) + candidates = _vsolve(expr, symbol, **flags) + + for candidate in candidates: + if candidate in result: + # an unconditional value was already there + continue + try: + v = cond.subs(symbol, candidate) + _eval_simplify = getattr(v, '_eval_simplify', None) + if _eval_simplify is not None: + # unconditionally take the simplification of v + v = _eval_simplify(ratio=2, measure=lambda x: 1) + except TypeError: + # incompatible type with condition(s) + continue + if v == False: + continue + if v == True: + result.add(candidate) + else: + result.add(Piecewise( + (candidate, v), + (S.NaN, True))) + # solutions already checked and simplified + # **************************************** + return [{symbol: r} for r in result] + else: + # first see if it really depends on symbol and whether there + # is only a linear solution + f_num, sol = solve_linear(f, symbols=symbols) + if f_num.is_zero or sol is S.NaN: + return [] + elif f_num.is_Symbol: + # no need to check but simplify if desired + if flags.get('simplify', True): + sol = simplify(sol) + return [{f_num: sol}] + + poly = None + # check for a single Add generator + if not f_num.is_Add: + add_args = [i for i in f_num.atoms(Add) + if symbol in i.free_symbols] + if len(add_args) == 1: + gen = add_args[0] + spart = gen.as_independent(symbol)[1].as_base_exp()[0] + if spart == symbol: + try: + poly = Poly(f_num, spart) + except PolynomialError: + pass + + result = False # no solution was obtained + msg = '' # there is no failure message + + # Poly is generally robust enough to convert anything to + # a polynomial and tell us the different generators that it + # contains, so we will inspect the generators identified by + # polys to figure out what to do. + + # try to identify a single generator that will allow us to solve this + # as a polynomial, followed (perhaps) by a change of variables if the + # generator is not a symbol + + try: + if poly is None: + poly = Poly(f_num) + if poly is None: + raise ValueError('could not convert %s to Poly' % f_num) + except GeneratorsNeeded: + simplified_f = simplify(f_num) + if simplified_f != f_num: + return _solve(simplified_f, symbol, **flags) + raise ValueError('expression appears to be a constant') + + gens = [g for g in poly.gens if g.has(symbol)] + + def _as_base_q(x): + """Return (b**e, q) for x = b**(p*e/q) where p/q is the leading + Rational of the exponent of x, e.g. exp(-2*x/3) -> (exp(x), 3) + """ + b, e = x.as_base_exp() + if e.is_Rational: + return b, e.q + if not e.is_Mul: + return x, 1 + c, ee = e.as_coeff_Mul() + if c.is_Rational and c is not S.One: # c could be a Float + return b**ee, c.q + return x, 1 + + if len(gens) > 1: + # If there is more than one generator, it could be that the + # generators have the same base but different powers, e.g. + # >>> Poly(exp(x) + 1/exp(x)) + # Poly(exp(-x) + exp(x), exp(-x), exp(x), domain='ZZ') + # + # If unrad was not disabled then there should be no rational + # exponents appearing as in + # >>> Poly(sqrt(x) + sqrt(sqrt(x))) + # Poly(sqrt(x) + x**(1/4), sqrt(x), x**(1/4), domain='ZZ') + + bases, qs = list(zip(*[_as_base_q(g) for g in gens])) + bases = set(bases) + + if len(bases) > 1 or not all(q == 1 for q in qs): + funcs = {b for b in bases if b.is_Function} + + trig = {_ for _ in funcs if + isinstance(_, TrigonometricFunction)} + other = funcs - trig + if not other and len(funcs.intersection(trig)) > 1: + newf = None + if f_num.is_Add and len(f_num.args) == 2: + # check for sin(x)**p = cos(x)**p + _args = f_num.args + t = a, b = [i.atoms(Function).intersection( + trig) for i in _args] + if all(len(i) == 1 for i in t): + a, b = [i.pop() for i in t] + if isinstance(a, cos): + a, b = b, a + _args = _args[::-1] + if isinstance(a, sin) and isinstance(b, cos + ) and a.args[0] == b.args[0]: + # sin(x) + cos(x) = 0 -> tan(x) + 1 = 0 + newf, _d = (TR2i(_args[0]/_args[1]) + 1 + ).as_numer_denom() + if not _d.is_Number: + newf = None + if newf is None: + newf = TR1(f_num).rewrite(tan) + if newf != f_num: + # don't check the rewritten form --check + # solutions in the un-rewritten form below + flags['check'] = False + result = _solve(newf, symbol, **flags) + flags['check'] = check + + # just a simple case - see if replacement of single function + # clears all symbol-dependent functions, e.g. + # log(x) - log(log(x) - 1) - 3 can be solved even though it has + # two generators. + + if result is False and funcs: + funcs = list(ordered(funcs)) # put shallowest function first + f1 = funcs[0] + t = Dummy('t') + # perform the substitution + ftry = f_num.subs(f1, t) + + # if no Functions left, we can proceed with usual solve + if not ftry.has(symbol): + cv_sols = _solve(ftry, t, **flags) + cv_inv = list(ordered(_vsolve(t - f1, symbol, **flags)))[0] + result = [{symbol: cv_inv.subs(sol)} for sol in cv_sols] + + if result is False: + msg = 'multiple generators %s' % gens + + else: + # e.g. case where gens are exp(x), exp(-x) + u = bases.pop() + t = Dummy('t') + inv = _vsolve(u - t, symbol, **flags) + if isinstance(u, (Pow, exp)): + # this will be resolved by factor in _tsolve but we might + # as well try a simple expansion here to get things in + # order so something like the following will work now without + # having to factor: + # + # >>> eq = (exp(I*(-x-2))+exp(I*(x+2))) + # >>> eq.subs(exp(x),y) # fails + # exp(I*(-x - 2)) + exp(I*(x + 2)) + # >>> eq.expand().subs(exp(x),y) # works + # y**I*exp(2*I) + y**(-I)*exp(-2*I) + def _expand(p): + b, e = p.as_base_exp() + e = expand_mul(e) + return expand_power_exp(b**e) + ftry = f_num.replace( + lambda w: w.is_Pow or isinstance(w, exp), + _expand).subs(u, t) + if not ftry.has(symbol): + soln = _solve(ftry, t, **flags) + result = [{symbol: i.subs(s)} for i in inv for s in soln] + + elif len(gens) == 1: + + # There is only one generator that we are interested in, but + # there may have been more than one generator identified by + # polys (e.g. for symbols other than the one we are interested + # in) so recast the poly in terms of our generator of interest. + # Also use composite=True with f_num since Poly won't update + # poly as documented in issue 8810. + + poly = Poly(f_num, gens[0], composite=True) + + # if we aren't on the tsolve-pass, use roots + if not flags.pop('tsolve', False): + soln = None + deg = poly.degree() + flags['tsolve'] = True + hints = ('cubics', 'quartics', 'quintics') + solvers = {h: flags.get(h) for h in hints} + soln = roots(poly, **solvers) + if sum(soln.values()) < deg: + # e.g. roots(32*x**5 + 400*x**4 + 2032*x**3 + + # 5000*x**2 + 6250*x + 3189) -> {} + # so all_roots is used and RootOf instances are + # returned *unless* the system is multivariate + # or high-order EX domain. + try: + soln = poly.all_roots() + except NotImplementedError: + if not flags.get('incomplete', True): + raise NotImplementedError( + filldedent(''' + Neither high-order multivariate polynomials + nor sorting of EX-domain polynomials is supported. + If you want to see any results, pass keyword incomplete=True to + solve; to see numerical values of roots + for univariate expressions, use nroots. + ''')) + else: + pass + else: + soln = list(soln.keys()) + + if soln is not None: + u = poly.gen + if u != symbol: + try: + t = Dummy('t') + inv = _vsolve(u - t, symbol, **flags) + soln = {i.subs(t, s) for i in inv for s in soln} + except NotImplementedError: + # perhaps _tsolve can handle f_num + soln = None + else: + check = False # only dens need to be checked + if soln is not None: + if len(soln) > 2: + # if the flag wasn't set then unset it since high-order + # results are quite long. Perhaps one could base this + # decision on a certain critical length of the + # roots. In addition, wester test M2 has an expression + # whose roots can be shown to be real with the + # unsimplified form of the solution whereas only one of + # the simplified forms appears to be real. + flags['simplify'] = flags.get('simplify', False) + if soln is not None: + result = [{symbol: v} for v in soln] + + # fallback if above fails + # ----------------------- + if result is False: + # try unrad + if flags.pop('_unrad', True): + try: + u = unrad(f_num, symbol) + except (ValueError, NotImplementedError): + u = False + if u: + eq, cov = u + if cov: + isym, ieq = cov + inv = _vsolve(ieq, symbol, **flags)[0] + rv = {inv.subs(xi) for xi in _solve(eq, isym, **flags)} + else: + try: + rv = set(_vsolve(eq, symbol, **flags)) + except NotImplementedError: + rv = None + if rv is not None: + result = [{symbol: v} for v in rv] + # if the flag wasn't set then unset it since unrad results + # can be quite long or of very high order + flags['simplify'] = flags.get('simplify', False) + else: + pass # for coverage + + # try _tsolve + if result is False: + flags.pop('tsolve', None) # allow tsolve to be used on next pass + try: + soln = _tsolve(f_num, symbol, **flags) + if soln is not None: + result = [{symbol: v} for v in soln] + except PolynomialError: + pass + # ----------- end of fallback ---------------------------- + + if result is False: + raise NotImplementedError('\n'.join([msg, not_impl_msg % f])) + + result = _remove_duplicate_solutions(result) + + if flags.get('simplify', True): + result = [{k: d[k].simplify() for k in d} for d in result] + # Simplification might reveal more duplicates + result = _remove_duplicate_solutions(result) + # we just simplified the solution so we now set the flag to + # False so the simplification doesn't happen again in checksol() + flags['simplify'] = False + + if checkdens: + # reject any result that makes any denom. affirmatively 0; + # if in doubt, keep it + dens = _simple_dens(f, symbols) + result = [r for r in result if + not any(checksol(d, r, **flags) + for d in dens)] + if check: + # keep only results if the check is not False + result = [r for r in result if + checksol(f_num, r, **flags) is not False] + return result + + +def _remove_duplicate_solutions(solutions: list[dict[Expr, Expr]] + ) -> list[dict[Expr, Expr]]: + """Remove duplicates from a list of dicts""" + solutions_set = set() + solutions_new = [] + + for sol in solutions: + solset = frozenset(sol.items()) + if solset not in solutions_set: + solutions_new.append(sol) + solutions_set.add(solset) + + return solutions_new + + +def _solve_system(exprs, symbols, **flags): + """return ``(linear, solution)`` where ``linear`` is True + if the system was linear, else False; ``solution`` + is a list of dictionaries giving solutions for the symbols + """ + if not exprs: + return False, [] + + if flags.pop('_split', True): + # Split the system into connected components + V = exprs + symsset = set(symbols) + exprsyms = {e: e.free_symbols & symsset for e in exprs} + E = [] + sym_indices = {sym: i for i, sym in enumerate(symbols)} + for n, e1 in enumerate(exprs): + for e2 in exprs[:n]: + # Equations are connected if they share a symbol + if exprsyms[e1] & exprsyms[e2]: + E.append((e1, e2)) + G = V, E + subexprs = connected_components(G) + if len(subexprs) > 1: + subsols = [] + linear = True + for subexpr in subexprs: + subsyms = set() + for e in subexpr: + subsyms |= exprsyms[e] + subsyms = sorted(subsyms, key = lambda x: sym_indices[x]) + flags['_split'] = False # skip split step + _linear, subsol = _solve_system(subexpr, subsyms, **flags) + if linear: + linear = linear and _linear + if not isinstance(subsol, list): + subsol = [subsol] + subsols.append(subsol) + # Full solution is cartesian product of subsystems + sols = [] + for soldicts in product(*subsols): + sols.append(dict(item for sd in soldicts + for item in sd.items())) + return linear, sols + + polys = [] + dens = set() + failed = [] + result = [] + solved_syms = [] + linear = True + manual = flags.get('manual', False) + checkdens = check = flags.get('check', True) + + for j, g in enumerate(exprs): + dens.update(_simple_dens(g, symbols)) + i, d = _invert(g, *symbols) + if d in symbols: + if linear: + linear = solve_linear(g, 0, [d])[0] == d + g = d - i + g = g.as_numer_denom()[0] + if manual: + failed.append(g) + continue + + poly = g.as_poly(*symbols, extension=True) + + if poly is not None: + polys.append(poly) + else: + failed.append(g) + + if polys: + if all(p.is_linear for p in polys): + n, m = len(polys), len(symbols) + matrix = zeros(n, m + 1) + + for i, poly in enumerate(polys): + for monom, coeff in poly.terms(): + try: + j = monom.index(1) + matrix[i, j] = coeff + except ValueError: + matrix[i, m] = -coeff + + # returns a dictionary ({symbols: values}) or None + if flags.pop('particular', False): + result = minsolve_linear_system(matrix, *symbols, **flags) + else: + result = solve_linear_system(matrix, *symbols, **flags) + result = [result] if result else [] + if failed: + if result: + solved_syms = list(result[0].keys()) # there is only one result dict + else: + solved_syms = [] + # linear doesn't change + else: + linear = False + if len(symbols) > len(polys): + + free = set().union(*[p.free_symbols for p in polys]) + free = list(ordered(free.intersection(symbols))) + got_s = set() + result = [] + for syms in subsets(free, min(len(free), len(polys))): + try: + # returns [], None or list of tuples + res = solve_poly_system(polys, *syms) + if res: + for r in set(res): + skip = False + for r1 in r: + if got_s and any(ss in r1.free_symbols + for ss in got_s): + # sol depends on previously + # solved symbols: discard it + skip = True + if not skip: + got_s.update(syms) + result.append(dict(list(zip(syms, r)))) + except NotImplementedError: + pass + if got_s: + solved_syms = list(got_s) + else: + failed.extend([g.as_expr() for g in polys]) + else: + try: + result = solve_poly_system(polys, *symbols) + if result: + solved_syms = symbols + result = [dict(list(zip(solved_syms, r))) for r in set(result)] + except NotImplementedError: + failed.extend([g.as_expr() for g in polys]) + solved_syms = [] + + # convert None or [] to [{}] + result = result or [{}] + + if failed: + linear = False + # For each failed equation, see if we can solve for one of the + # remaining symbols from that equation. If so, we update the + # solution set and continue with the next failed equation, + # repeating until we are done or we get an equation that can't + # be solved. + def _ok_syms(e, sort=False): + rv = e.free_symbols & legal + + # Solve first for symbols that have lower degree in the equation. + # Ideally we want to solve firstly for symbols that appear linearly + # with rational coefficients e.g. if e = x*y + z then we should + # solve for z first. + def key(sym): + ep = e.as_poly(sym) + if ep is None: + complexity = (S.Infinity, S.Infinity, S.Infinity) + else: + coeff_syms = ep.LC().free_symbols + complexity = (ep.degree(), len(coeff_syms & rv), len(coeff_syms)) + return complexity + (default_sort_key(sym),) + + if sort: + rv = sorted(rv, key=key) + return rv + + legal = set(symbols) # what we are interested in + # sort so equation with the fewest potential symbols is first + u = Dummy() # used in solution checking + for eq in ordered(failed, lambda _: len(_ok_syms(_))): + newresult = [] + bad_results = [] + hit = False + for r in result: + got_s = set() + # update eq with everything that is known so far + eq2 = eq.subs(r) + # if check is True then we see if it satisfies this + # equation, otherwise we just accept it + if check and r: + b = checksol(u, u, eq2, minimal=True) + if b is not None: + # this solution is sufficient to know whether + # it is valid or not so we either accept or + # reject it, then continue + if b: + newresult.append(r) + else: + bad_results.append(r) + continue + # search for a symbol amongst those available that + # can be solved for + ok_syms = _ok_syms(eq2, sort=True) + if not ok_syms: + if r: + newresult.append(r) + break # skip as it's independent of desired symbols + for s in ok_syms: + try: + soln = _vsolve(eq2, s, **flags) + except NotImplementedError: + continue + # put each solution in r and append the now-expanded + # result in the new result list; use copy since the + # solution for s is being added in-place + for sol in soln: + if got_s and any(ss in sol.free_symbols for ss in got_s): + # sol depends on previously solved symbols: discard it + continue + rnew = r.copy() + for k, v in r.items(): + rnew[k] = v.subs(s, sol) + # and add this new solution + rnew[s] = sol + # check that it is independent of previous solutions + iset = set(rnew.items()) + for i in newresult: + if len(i) < len(iset): + # update i with what is known + i_items_updated = {(k, v.xreplace(rnew)) for k, v in i.items()} + if not i_items_updated - iset: + # this is a superset of a known solution that + # is smaller + break + else: + # keep it + newresult.append(rnew) + hit = True + got_s.add(s) + if not hit: + raise NotImplementedError('could not solve %s' % eq2) + else: + result = newresult + for b in bad_results: + if b in result: + result.remove(b) + + if not result: + return False, [] + + # rely on linear/polynomial system solvers to simplify + # XXX the following tests show that the expressions + # returned are not the same as they would be if simplify + # were applied to this: + # sympy/solvers/ode/tests/test_systems/test__classify_linear_system + # sympy/solvers/tests/test_solvers/test_issue_4886 + # so the docs should be updated to reflect that or else + # the following should be `bool(failed) or not linear` + default_simplify = bool(failed) + if flags.get('simplify', default_simplify): + for r in result: + for k in r: + r[k] = simplify(r[k]) + flags['simplify'] = False # don't need to do so in checksol now + + if checkdens: + result = [r for r in result + if not any(checksol(d, r, **flags) for d in dens)] + + if check and not linear: + result = [r for r in result + if not any(checksol(e, r, **flags) is False for e in exprs)] + + result = [r for r in result if r] + return linear, result + + +def solve_linear(lhs, rhs=0, symbols=[], exclude=[]): + r""" + Return a tuple derived from ``f = lhs - rhs`` that is one of + the following: ``(0, 1)``, ``(0, 0)``, ``(symbol, solution)``, ``(n, d)``. + + Explanation + =========== + + ``(0, 1)`` meaning that ``f`` is independent of the symbols in *symbols* + that are not in *exclude*. + + ``(0, 0)`` meaning that there is no solution to the equation amongst the + symbols given. If the first element of the tuple is not zero, then the + function is guaranteed to be dependent on a symbol in *symbols*. + + ``(symbol, solution)`` where symbol appears linearly in the numerator of + ``f``, is in *symbols* (if given), and is not in *exclude* (if given). No + simplification is done to ``f`` other than a ``mul=True`` expansion, so the + solution will correspond strictly to a unique solution. + + ``(n, d)`` where ``n`` and ``d`` are the numerator and denominator of ``f`` + when the numerator was not linear in any symbol of interest; ``n`` will + never be a symbol unless a solution for that symbol was found (in which case + the second element is the solution, not the denominator). + + Examples + ======== + + >>> from sympy import cancel, Pow + + ``f`` is independent of the symbols in *symbols* that are not in + *exclude*: + + >>> from sympy import cos, sin, solve_linear + >>> from sympy.abc import x, y, z + >>> eq = y*cos(x)**2 + y*sin(x)**2 - y # = y*(1 - 1) = 0 + >>> solve_linear(eq) + (0, 1) + >>> eq = cos(x)**2 + sin(x)**2 # = 1 + >>> solve_linear(eq) + (0, 1) + >>> solve_linear(x, exclude=[x]) + (0, 1) + + The variable ``x`` appears as a linear variable in each of the + following: + + >>> solve_linear(x + y**2) + (x, -y**2) + >>> solve_linear(1/x - y**2) + (x, y**(-2)) + + When not linear in ``x`` or ``y`` then the numerator and denominator are + returned: + + >>> solve_linear(x**2/y**2 - 3) + (x**2 - 3*y**2, y**2) + + If the numerator of the expression is a symbol, then ``(0, 0)`` is + returned if the solution for that symbol would have set any + denominator to 0: + + >>> eq = 1/(1/x - 2) + >>> eq.as_numer_denom() + (x, 1 - 2*x) + >>> solve_linear(eq) + (0, 0) + + But automatic rewriting may cause a symbol in the denominator to + appear in the numerator so a solution will be returned: + + >>> (1/x)**-1 + x + >>> solve_linear((1/x)**-1) + (x, 0) + + Use an unevaluated expression to avoid this: + + >>> solve_linear(Pow(1/x, -1, evaluate=False)) + (0, 0) + + If ``x`` is allowed to cancel in the following expression, then it + appears to be linear in ``x``, but this sort of cancellation is not + done by ``solve_linear`` so the solution will always satisfy the + original expression without causing a division by zero error. + + >>> eq = x**2*(1/x - z**2/x) + >>> solve_linear(cancel(eq)) + (x, 0) + >>> solve_linear(eq) + (x**2*(1 - z**2), x) + + A list of symbols for which a solution is desired may be given: + + >>> solve_linear(x + y + z, symbols=[y]) + (y, -x - z) + + A list of symbols to ignore may also be given: + + >>> solve_linear(x + y + z, exclude=[x]) + (y, -x - z) + + (A solution for ``y`` is obtained because it is the first variable + from the canonically sorted list of symbols that had a linear + solution.) + + """ + if isinstance(lhs, Eq): + if rhs: + raise ValueError(filldedent(''' + If lhs is an Equality, rhs must be 0 but was %s''' % rhs)) + rhs = lhs.rhs + lhs = lhs.lhs + dens = None + eq = lhs - rhs + n, d = eq.as_numer_denom() + if not n: + return S.Zero, S.One + + free = n.free_symbols + if not symbols: + symbols = free + else: + bad = [s for s in symbols if not s.is_Symbol] + if bad: + if len(bad) == 1: + bad = bad[0] + if len(symbols) == 1: + eg = 'solve(%s, %s)' % (eq, symbols[0]) + else: + eg = 'solve(%s, *%s)' % (eq, list(symbols)) + raise ValueError(filldedent(''' + solve_linear only handles symbols, not %s. To isolate + non-symbols use solve, e.g. >>> %s <<<. + ''' % (bad, eg))) + symbols = free.intersection(symbols) + symbols = symbols.difference(exclude) + if not symbols: + return S.Zero, S.One + + # derivatives are easy to do but tricky to analyze to see if they + # are going to disallow a linear solution, so for simplicity we + # just evaluate the ones that have the symbols of interest + derivs = defaultdict(list) + for der in n.atoms(Derivative): + csym = der.free_symbols & symbols + for c in csym: + derivs[c].append(der) + + all_zero = True + for xi in sorted(symbols, key=default_sort_key): # canonical order + # if there are derivatives in this var, calculate them now + if isinstance(derivs[xi], list): + derivs[xi] = {der: der.doit() for der in derivs[xi]} + newn = n.subs(derivs[xi]) + dnewn_dxi = newn.diff(xi) + # dnewn_dxi can be nonzero if it survives differentation by any + # of its free symbols + free = dnewn_dxi.free_symbols + if dnewn_dxi and (not free or any(dnewn_dxi.diff(s) for s in free) or free == symbols): + all_zero = False + if dnewn_dxi is S.NaN: + break + if xi not in dnewn_dxi.free_symbols: + vi = -1/dnewn_dxi*(newn.subs(xi, 0)) + if dens is None: + dens = _simple_dens(eq, symbols) + if not any(checksol(di, {xi: vi}, minimal=True) is True + for di in dens): + # simplify any trivial integral + irep = [(i, i.doit()) for i in vi.atoms(Integral) if + i.function.is_number] + # do a slight bit of simplification + vi = expand_mul(vi.subs(irep)) + return xi, vi + if all_zero: + return S.Zero, S.One + if n.is_Symbol: # no solution for this symbol was found + return S.Zero, S.Zero + return n, d + + +def minsolve_linear_system(system, *symbols, **flags): + r""" + Find a particular solution to a linear system. + + Explanation + =========== + + In particular, try to find a solution with the minimal possible number + of non-zero variables using a naive algorithm with exponential complexity. + If ``quick=True``, a heuristic is used. + + """ + quick = flags.get('quick', False) + # Check if there are any non-zero solutions at all + s0 = solve_linear_system(system, *symbols, **flags) + if not s0 or all(v == 0 for v in s0.values()): + return s0 + if quick: + # We just solve the system and try to heuristically find a nice + # solution. + s = solve_linear_system(system, *symbols) + def update(determined, solution): + delete = [] + for k, v in solution.items(): + solution[k] = v.subs(determined) + if not solution[k].free_symbols: + delete.append(k) + determined[k] = solution[k] + for k in delete: + del solution[k] + determined = {} + update(determined, s) + while s: + # NOTE sort by default_sort_key to get deterministic result + k = max((k for k in s.values()), + key=lambda x: (len(x.free_symbols), default_sort_key(x))) + kfree = k.free_symbols + x = next(reversed(list(ordered(kfree)))) + if len(kfree) != 1: + determined[x] = S.Zero + else: + val = _vsolve(k, x, check=False)[0] + if not val and not any(v.subs(x, val) for v in s.values()): + determined[x] = S.One + else: + determined[x] = val + update(determined, s) + return determined + else: + # We try to select n variables which we want to be non-zero. + # All others will be assumed zero. We try to solve the modified system. + # If there is a non-trivial solution, just set the free variables to + # one. If we do this for increasing n, trying all combinations of + # variables, we will find an optimal solution. + # We speed up slightly by starting at one less than the number of + # variables the quick method manages. + N = len(symbols) + bestsol = minsolve_linear_system(system, *symbols, quick=True) + n0 = len([x for x in bestsol.values() if x != 0]) + for n in range(n0 - 1, 1, -1): + debugf('minsolve: %s', n) + thissol = None + for nonzeros in combinations(range(N), n): + subm = Matrix([system.col(i).T for i in nonzeros] + [system.col(-1).T]).T + s = solve_linear_system(subm, *[symbols[i] for i in nonzeros]) + if s and not all(v == 0 for v in s.values()): + subs = [(symbols[v], S.One) for v in nonzeros] + for k, v in s.items(): + s[k] = v.subs(subs) + for sym in symbols: + if sym not in s: + if symbols.index(sym) in nonzeros: + s[sym] = S.One + else: + s[sym] = S.Zero + thissol = s + break + if thissol is None: + break + bestsol = thissol + return bestsol + + +def solve_linear_system(system, *symbols, **flags): + r""" + Solve system of $N$ linear equations with $M$ variables, which means + both under- and overdetermined systems are supported. + + Explanation + =========== + + The possible number of solutions is zero, one, or infinite. Respectively, + this procedure will return None or a dictionary with solutions. In the + case of underdetermined systems, all arbitrary parameters are skipped. + This may cause a situation in which an empty dictionary is returned. + In that case, all symbols can be assigned arbitrary values. + + Input to this function is a $N\times M + 1$ matrix, which means it has + to be in augmented form. If you prefer to enter $N$ equations and $M$ + unknowns then use ``solve(Neqs, *Msymbols)`` instead. Note: a local + copy of the matrix is made by this routine so the matrix that is + passed will not be modified. + + The algorithm used here is fraction-free Gaussian elimination, + which results, after elimination, in an upper-triangular matrix. + Then solutions are found using back-substitution. This approach + is more efficient and compact than the Gauss-Jordan method. + + Examples + ======== + + >>> from sympy import Matrix, solve_linear_system + >>> from sympy.abc import x, y + + Solve the following system:: + + x + 4 y == 2 + -2 x + y == 14 + + >>> system = Matrix(( (1, 4, 2), (-2, 1, 14))) + >>> solve_linear_system(system, x, y) + {x: -6, y: 2} + + A degenerate system returns an empty dictionary: + + >>> system = Matrix(( (0,0,0), (0,0,0) )) + >>> solve_linear_system(system, x, y) + {} + + """ + assert system.shape[1] == len(symbols) + 1 + + # This is just a wrapper for solve_lin_sys + eqs = list(system * Matrix(symbols + (-1,))) + eqs, ring = sympy_eqs_to_ring(eqs, symbols) + sol = solve_lin_sys(eqs, ring, _raw=False) + if sol is not None: + sol = {sym:val for sym, val in sol.items() if sym != val} + return sol + + +def solve_undetermined_coeffs(equ, coeffs, *syms, **flags): + r""" + Solve a system of equations in $k$ parameters that is formed by + matching coefficients in variables ``coeffs`` that are on + factors dependent on the remaining variables (or those given + explicitly by ``syms``. + + Explanation + =========== + + The result of this function is a dictionary with symbolic values of those + parameters with respect to coefficients in $q$ -- empty if there + is no solution or coefficients do not appear in the equation -- else + None (if the system was not recognized). If there is more than one + solution, the solutions are passed as a list. The output can be modified using + the same semantics as for `solve` since the flags that are passed are sent + directly to `solve` so, for example the flag ``dict=True`` will always return a list + of solutions as dictionaries. + + This function accepts both Equality and Expr class instances. + The solving process is most efficient when symbols are specified + in addition to parameters to be determined, but an attempt to + determine them (if absent) will be made. If an expected solution is not + obtained (and symbols were not specified) try specifying them. + + Examples + ======== + + >>> from sympy import Eq, solve_undetermined_coeffs + >>> from sympy.abc import a, b, c, h, p, k, x, y + + >>> solve_undetermined_coeffs(Eq(a*x + a + b, x/2), [a, b], x) + {a: 1/2, b: -1/2} + >>> solve_undetermined_coeffs(a - 2, [a]) + {a: 2} + + The equation can be nonlinear in the symbols: + + >>> X, Y, Z = y, x**y, y*x**y + >>> eq = a*X + b*Y + c*Z - X - 2*Y - 3*Z + >>> coeffs = a, b, c + >>> syms = x, y + >>> solve_undetermined_coeffs(eq, coeffs, syms) + {a: 1, b: 2, c: 3} + + And the system can be nonlinear in coefficients, too, but if + there is only a single solution, it will be returned as a + dictionary: + + >>> eq = a*x**2 + b*x + c - ((x - h)**2 + 4*p*k)/4/p + >>> solve_undetermined_coeffs(eq, (h, p, k), x) + {h: -b/(2*a), k: (4*a*c - b**2)/(4*a), p: 1/(4*a)} + + Multiple solutions are always returned in a list: + + >>> solve_undetermined_coeffs(a**2*x + b - x, [a, b], x) + [{a: -1, b: 0}, {a: 1, b: 0}] + + Using flag ``dict=True`` (in keeping with semantics in :func:`~.solve`) + will force the result to always be a list with any solutions + as elements in that list. + + >>> solve_undetermined_coeffs(a*x - 2*x, [a], dict=True) + [{a: 2}] + """ + if not (coeffs and all(i.is_Symbol for i in coeffs)): + raise ValueError('must provide symbols for coeffs') + + if isinstance(equ, Eq): + eq = equ.lhs - equ.rhs + else: + eq = equ + + ceq = cancel(eq) + xeq = _mexpand(ceq.as_numer_denom()[0], recursive=True) + + free = xeq.free_symbols + coeffs = free & set(coeffs) + if not coeffs: + return ([], {}) if flags.get('set', None) else [] # solve(0, x) -> [] + + if not syms: + # e.g. A*exp(x) + B - (exp(x) + y) separated into parts that + # don't/do depend on coeffs gives + # -(exp(x) + y), A*exp(x) + B + # then see what symbols are common to both + # {x} = {x, A, B} - {x, y} + ind, dep = xeq.as_independent(*coeffs, as_Add=True) + dfree = dep.free_symbols + syms = dfree & ind.free_symbols + if not syms: + # but if the system looks like (a + b)*x + b - c + # then {} = {a, b, x} - c + # so calculate {x} = {a, b, x} - {a, b} + syms = dfree - set(coeffs) + if not syms: + syms = [Dummy()] + else: + if len(syms) == 1 and iterable(syms[0]): + syms = syms[0] + e, s, _ = recast_to_symbols([xeq], syms) + xeq = e[0] + syms = s + + # find the functional forms in which symbols appear + + gens = set(xeq.as_coefficients_dict(*syms).keys()) - {1} + cset = set(coeffs) + if any(g.has_xfree(cset) for g in gens): + return # a generator contained a coefficient symbol + + # make sure we are working with symbols for generators + + e, gens, _ = recast_to_symbols([xeq], list(gens)) + xeq = e[0] + + # collect coefficients in front of generators + + system = list(collect(xeq, gens, evaluate=False).values()) + + # get a solution + + soln = solve(system, coeffs, **flags) + + # unpack unless told otherwise if length is 1 + + settings = flags.get('dict', None) or flags.get('set', None) + if type(soln) is dict or settings or len(soln) != 1: + return soln + return soln[0] + + +def solve_linear_system_LU(matrix, syms): + """ + Solves the augmented matrix system using ``LUsolve`` and returns a + dictionary in which solutions are keyed to the symbols of *syms* as ordered. + + Explanation + =========== + + The matrix must be invertible. + + Examples + ======== + + >>> from sympy import Matrix, solve_linear_system_LU + >>> from sympy.abc import x, y, z + + >>> solve_linear_system_LU(Matrix([ + ... [1, 2, 0, 1], + ... [3, 2, 2, 1], + ... [2, 0, 0, 1]]), [x, y, z]) + {x: 1/2, y: 1/4, z: -1/2} + + See Also + ======== + + LUsolve + + """ + if matrix.rows != matrix.cols - 1: + raise ValueError("Rows should be equal to columns - 1") + A = matrix[:matrix.rows, :matrix.rows] + b = matrix[:, matrix.cols - 1:] + soln = A.LUsolve(b) + solutions = {} + for i in range(soln.rows): + solutions[syms[i]] = soln[i, 0] + return solutions + + +def det_perm(M): + """ + Return the determinant of *M* by using permutations to select factors. + + Explanation + =========== + + For sizes larger than 8 the number of permutations becomes prohibitively + large, or if there are no symbols in the matrix, it is better to use the + standard determinant routines (e.g., ``M.det()``.) + + See Also + ======== + + det_minor + det_quick + + """ + args = [] + s = True + n = M.rows + list_ = M.flat() + for perm in generate_bell(n): + fac = [] + idx = 0 + for j in perm: + fac.append(list_[idx + j]) + idx += n + term = Mul(*fac) # disaster with unevaluated Mul -- takes forever for n=7 + args.append(term if s else -term) + s = not s + return Add(*args) + + +def det_minor(M): + """ + Return the ``det(M)`` computed from minors without + introducing new nesting in products. + + See Also + ======== + + det_perm + det_quick + + """ + n = M.rows + if n == 2: + return M[0, 0]*M[1, 1] - M[1, 0]*M[0, 1] + else: + return sum((1, -1)[i % 2]*Add(*[M[0, i]*d for d in + Add.make_args(det_minor(M.minor_submatrix(0, i)))]) + if M[0, i] else S.Zero for i in range(n)) + + +def det_quick(M, method=None): + """ + Return ``det(M)`` assuming that either + there are lots of zeros or the size of the matrix + is small. If this assumption is not met, then the normal + Matrix.det function will be used with method = ``method``. + + See Also + ======== + + det_minor + det_perm + + """ + if any(i.has(Symbol) for i in M): + if M.rows < 8 and all(i.has(Symbol) for i in M): + return det_perm(M) + return det_minor(M) + else: + return M.det(method=method) if method else M.det() + + +def inv_quick(M): + """Return the inverse of ``M``, assuming that either + there are lots of zeros or the size of the matrix + is small. + """ + if not all(i.is_Number for i in M): + if not any(i.is_Number for i in M): + det = lambda _: det_perm(_) + else: + det = lambda _: det_minor(_) + else: + return M.inv() + n = M.rows + d = det(M) + if d == S.Zero: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible") + ret = zeros(n) + s1 = -1 + for i in range(n): + s = s1 = -s1 + for j in range(n): + di = det(M.minor_submatrix(i, j)) + ret[j, i] = s*di/d + s = -s + return ret + + +# these are functions that have multiple inverse values per period +multi_inverses = { + sin: lambda x: (asin(x), S.Pi - asin(x)), + cos: lambda x: (acos(x), 2*S.Pi - acos(x)), +} + + +def _vsolve(e, s, **flags): + """return list of scalar values for the solution of e for symbol s""" + return [i[s] for i in _solve(e, s, **flags)] + + +def _tsolve(eq, sym, **flags): + """ + Helper for ``_solve`` that solves a transcendental equation with respect + to the given symbol. Various equations containing powers and logarithms, + can be solved. + + There is currently no guarantee that all solutions will be returned or + that a real solution will be favored over a complex one. + + Either a list of potential solutions will be returned or None will be + returned (in the case that no method was known to get a solution + for the equation). All other errors (like the inability to cast an + expression as a Poly) are unhandled. + + Examples + ======== + + >>> from sympy import log, ordered + >>> from sympy.solvers.solvers import _tsolve as tsolve + >>> from sympy.abc import x + + >>> list(ordered(tsolve(3**(2*x + 5) - 4, x))) + [-5/2 + log(2)/log(3), (-5*log(3)/2 + log(2) + I*pi)/log(3)] + + >>> tsolve(log(x) + 2*x, x) + [LambertW(2)/2] + + """ + if 'tsolve_saw' not in flags: + flags['tsolve_saw'] = [] + if eq in flags['tsolve_saw']: + return None + else: + flags['tsolve_saw'].append(eq) + + rhs, lhs = _invert(eq, sym) + + if lhs == sym: + return [rhs] + try: + if lhs.is_Add: + # it's time to try factoring; powdenest is used + # to try get powers in standard form for better factoring + f = factor(powdenest(lhs - rhs)) + if f.is_Mul: + return _vsolve(f, sym, **flags) + if rhs: + f = logcombine(lhs, force=flags.get('force', True)) + if f.count(log) != lhs.count(log): + if isinstance(f, log): + return _vsolve(f.args[0] - exp(rhs), sym, **flags) + return _tsolve(f - rhs, sym, **flags) + + elif lhs.is_Pow: + if lhs.exp.is_Integer: + if lhs - rhs != eq: + return _vsolve(lhs - rhs, sym, **flags) + + if sym not in lhs.exp.free_symbols: + return _vsolve(lhs.base - rhs**(1/lhs.exp), sym, **flags) + + # _tsolve calls this with Dummy before passing the actual number in. + if any(t.is_Dummy for t in rhs.free_symbols): + raise NotImplementedError # _tsolve will call here again... + + # a ** g(x) == 0 + if not rhs: + # f(x)**g(x) only has solutions where f(x) == 0 and g(x) != 0 at + # the same place + sol_base = _vsolve(lhs.base, sym, **flags) + return [s for s in sol_base if lhs.exp.subs(sym, s) != 0] # XXX use checksol here? + + # a ** g(x) == b + if not lhs.base.has(sym): + if lhs.base == 0: + return _vsolve(lhs.exp, sym, **flags) if rhs != 0 else [] + + # Gets most solutions... + if lhs.base == rhs.as_base_exp()[0]: + # handles case when bases are equal + sol = _vsolve(lhs.exp - rhs.as_base_exp()[1], sym, **flags) + else: + # handles cases when bases are not equal and exp + # may or may not be equal + f = exp(log(lhs.base)*lhs.exp) - exp(log(rhs)) + sol = _vsolve(f, sym, **flags) + + # Check for duplicate solutions + def equal(expr1, expr2): + _ = Dummy() + eq = checksol(expr1 - _, _, expr2) + if eq is None: + if nsimplify(expr1) != nsimplify(expr2): + return False + # they might be coincidentally the same + # so check more rigorously + eq = expr1.equals(expr2) # XXX expensive but necessary? + return eq + + # Guess a rational exponent + e_rat = nsimplify(log(abs(rhs))/log(abs(lhs.base))) + e_rat = simplify(posify(e_rat)[0]) + n, d = fraction(e_rat) + if expand(lhs.base**n - rhs**d) == 0: + sol = [s for s in sol if not equal(lhs.exp.subs(sym, s), e_rat)] + sol.extend(_vsolve(lhs.exp - e_rat, sym, **flags)) + + return list(set(sol)) + + # f(x) ** g(x) == c + else: + sol = [] + logform = lhs.exp*log(lhs.base) - log(rhs) + if logform != lhs - rhs: + try: + sol.extend(_vsolve(logform, sym, **flags)) + except NotImplementedError: + pass + + # Collect possible solutions and check with substitution later. + check = [] + if rhs == 1: + # f(x) ** g(x) = 1 -- g(x)=0 or f(x)=+-1 + check.extend(_vsolve(lhs.exp, sym, **flags)) + check.extend(_vsolve(lhs.base - 1, sym, **flags)) + check.extend(_vsolve(lhs.base + 1, sym, **flags)) + elif rhs.is_Rational: + for d in (i for i in divisors(abs(rhs.p)) if i != 1): + e, t = integer_log(rhs.p, d) + if not t: + continue # rhs.p != d**b + for s in divisors(abs(rhs.q)): + if s**e== rhs.q: + r = Rational(d, s) + check.extend(_vsolve(lhs.base - r, sym, **flags)) + check.extend(_vsolve(lhs.base + r, sym, **flags)) + check.extend(_vsolve(lhs.exp - e, sym, **flags)) + elif rhs.is_irrational: + b_l, e_l = lhs.base.as_base_exp() + n, d = (e_l*lhs.exp).as_numer_denom() + b, e = sqrtdenest(rhs).as_base_exp() + check = [sqrtdenest(i) for i in (_vsolve(lhs.base - b, sym, **flags))] + check.extend([sqrtdenest(i) for i in (_vsolve(lhs.exp - e, sym, **flags))]) + if e_l*d != 1: + check.extend(_vsolve(b_l**n - rhs**(e_l*d), sym, **flags)) + for s in check: + ok = checksol(eq, sym, s) + if ok is None: + ok = eq.subs(sym, s).equals(0) + if ok: + sol.append(s) + return list(set(sol)) + + elif lhs.is_Function and len(lhs.args) == 1: + if lhs.func in multi_inverses: + # sin(x) = 1/3 -> x - asin(1/3) & x - (pi - asin(1/3)) + soln = [] + for i in multi_inverses[type(lhs)](rhs): + soln.extend(_vsolve(lhs.args[0] - i, sym, **flags)) + return list(set(soln)) + elif lhs.func == LambertW: + return _vsolve(lhs.args[0] - rhs*exp(rhs), sym, **flags) + + rewrite = lhs.rewrite(exp) + rewrite = rebuild(rewrite) # avoid rewrites involving evaluate=False + if rewrite != lhs: + return _vsolve(rewrite - rhs, sym, **flags) + except NotImplementedError: + pass + + # maybe it is a lambert pattern + if flags.pop('bivariate', True): + # lambert forms may need some help being recognized, e.g. changing + # 2**(3*x) + x**3*log(2)**3 + 3*x**2*log(2)**2 + 3*x*log(2) + 1 + # to 2**(3*x) + (x*log(2) + 1)**3 + + # make generator in log have exponent of 1 + logs = eq.atoms(log) + spow = min( + {i.exp for j in logs for i in j.atoms(Pow) + if i.base == sym} or {1}) + if spow != 1: + p = sym**spow + u = Dummy('bivariate-cov') + ueq = eq.subs(p, u) + if not ueq.has_free(sym): + sol = _vsolve(ueq, u, **flags) + inv = _vsolve(p - u, sym) + return [i.subs(u, s) for i in inv for s in sol] + + g = _filtered_gens(eq.as_poly(), sym) + up_or_log = set() + for gi in g: + if isinstance(gi, (exp, log)) or (gi.is_Pow and gi.base == S.Exp1): + up_or_log.add(gi) + elif gi.is_Pow: + gisimp = powdenest(expand_power_exp(gi)) + if gisimp.is_Pow and sym in gisimp.exp.free_symbols: + up_or_log.add(gi) + eq_down = expand_log(expand_power_exp(eq)).subs( + dict(list(zip(up_or_log, [0]*len(up_or_log))))) + eq = expand_power_exp(factor(eq_down, deep=True) + (eq - eq_down)) + rhs, lhs = _invert(eq, sym) + if lhs.has(sym): + try: + poly = lhs.as_poly() + g = _filtered_gens(poly, sym) + _eq = lhs - rhs + sols = _solve_lambert(_eq, sym, g) + # use a simplified form if it satisfies eq + # and has fewer operations + for n, s in enumerate(sols): + ns = nsimplify(s) + if ns != s and ns.count_ops() <= s.count_ops(): + ok = checksol(_eq, sym, ns) + if ok is None: + ok = _eq.subs(sym, ns).equals(0) + if ok: + sols[n] = ns + return sols + except NotImplementedError: + # maybe it's a convoluted function + if len(g) == 2: + try: + gpu = bivariate_type(lhs - rhs, *g) + if gpu is None: + raise NotImplementedError + g, p, u = gpu + flags['bivariate'] = False + inversion = _tsolve(g - u, sym, **flags) + if inversion: + sol = _vsolve(p, u, **flags) + return list({i.subs(u, s) + for i in inversion for s in sol}) + except NotImplementedError: + pass + else: + pass + + if flags.pop('force', True): + flags['force'] = False + pos, reps = posify(lhs - rhs) + if rhs == S.ComplexInfinity: + return [] + for u, s in reps.items(): + if s == sym: + break + else: + u = sym + if pos.has(u): + try: + soln = _vsolve(pos, u, **flags) + return [s.subs(reps) for s in soln] + except NotImplementedError: + pass + else: + pass # here for coverage + + return # here for coverage + + +# TODO: option for calculating J numerically + +@conserve_mpmath_dps +def nsolve(*args, dict=False, **kwargs): + r""" + Solve a nonlinear equation system numerically: ``nsolve(f, [args,] x0, + modules=['mpmath'], **kwargs)``. + + Explanation + =========== + + ``f`` is a vector function of symbolic expressions representing the system. + *args* are the variables. If there is only one variable, this argument can + be omitted. ``x0`` is a starting vector close to a solution. + + Use the modules keyword to specify which modules should be used to + evaluate the function and the Jacobian matrix. Make sure to use a module + that supports matrices. For more information on the syntax, please see the + docstring of ``lambdify``. + + If the keyword arguments contain ``dict=True`` (default is False) ``nsolve`` + will return a list (perhaps empty) of solution mappings. This might be + especially useful if you want to use ``nsolve`` as a fallback to solve since + using the dict argument for both methods produces return values of + consistent type structure. Please note: to keep this consistent with + ``solve``, the solution will be returned in a list even though ``nsolve`` + (currently at least) only finds one solution at a time. + + Overdetermined systems are supported. + + Examples + ======== + + >>> from sympy import Symbol, nsolve + >>> import mpmath + >>> mpmath.mp.dps = 15 + >>> x1 = Symbol('x1') + >>> x2 = Symbol('x2') + >>> f1 = 3 * x1**2 - 2 * x2**2 - 1 + >>> f2 = x1**2 - 2 * x1 + x2**2 + 2 * x2 - 8 + >>> print(nsolve((f1, f2), (x1, x2), (-1, 1))) + Matrix([[-1.19287309935246], [1.27844411169911]]) + + For one-dimensional functions the syntax is simplified: + + >>> from sympy import sin, nsolve + >>> from sympy.abc import x + >>> nsolve(sin(x), x, 2) + 3.14159265358979 + >>> nsolve(sin(x), 2) + 3.14159265358979 + + To solve with higher precision than the default, use the prec argument: + + >>> from sympy import cos + >>> nsolve(cos(x) - x, 1) + 0.739085133215161 + >>> nsolve(cos(x) - x, 1, prec=50) + 0.73908513321516064165531208767387340401341175890076 + >>> cos(_) + 0.73908513321516064165531208767387340401341175890076 + + To solve for complex roots of real functions, a nonreal initial point + must be specified: + + >>> from sympy import I + >>> nsolve(x**2 + 2, I) + 1.4142135623731*I + + ``mpmath.findroot`` is used and you can find their more extensive + documentation, especially concerning keyword parameters and + available solvers. Note, however, that functions which are very + steep near the root, the verification of the solution may fail. In + this case you should use the flag ``verify=False`` and + independently verify the solution. + + >>> from sympy import cos, cosh + >>> f = cos(x)*cosh(x) - 1 + >>> nsolve(f, 3.14*100) + Traceback (most recent call last): + ... + ValueError: Could not find root within given tolerance. (1.39267e+230 > 2.1684e-19) + >>> ans = nsolve(f, 3.14*100, verify=False); ans + 312.588469032184 + >>> f.subs(x, ans).n(2) + 2.1e+121 + >>> (f/f.diff(x)).subs(x, ans).n(2) + 7.4e-15 + + One might safely skip the verification if bounds of the root are known + and a bisection method is used: + + >>> bounds = lambda i: (3.14*i, 3.14*(i + 1)) + >>> nsolve(f, bounds(100), solver='bisect', verify=False) + 315.730061685774 + + Alternatively, a function may be better behaved when the + denominator is ignored. Since this is not always the case, however, + the decision of what function to use is left to the discretion of + the user. + + >>> eq = x**2/(1 - x)/(1 - 2*x)**2 - 100 + >>> nsolve(eq, 0.46) + Traceback (most recent call last): + ... + ValueError: Could not find root within given tolerance. (10000 > 2.1684e-19) + Try another starting point or tweak arguments. + >>> nsolve(eq.as_numer_denom()[0], 0.46) + 0.46792545969349058 + + """ + # there are several other SymPy functions that use method= so + # guard against that here + if 'method' in kwargs: + raise ValueError(filldedent(''' + Keyword "method" should not be used in this context. When using + some mpmath solvers directly, the keyword "method" is + used, but when using nsolve (and findroot) the keyword to use is + "solver".''')) + + if 'prec' in kwargs: + import mpmath + mpmath.mp.dps = kwargs.pop('prec') + + # keyword argument to return result as a dictionary + as_dict = dict + from builtins import dict # to unhide the builtin + + # interpret arguments + if len(args) == 3: + f = args[0] + fargs = args[1] + x0 = args[2] + if iterable(fargs) and iterable(x0): + if len(x0) != len(fargs): + raise TypeError('nsolve expected exactly %i guess vectors, got %i' + % (len(fargs), len(x0))) + elif len(args) == 2: + f = args[0] + fargs = None + x0 = args[1] + if iterable(f): + raise TypeError('nsolve expected 3 arguments, got 2') + elif len(args) < 2: + raise TypeError('nsolve expected at least 2 arguments, got %i' + % len(args)) + else: + raise TypeError('nsolve expected at most 3 arguments, got %i' + % len(args)) + modules = kwargs.get('modules', ['mpmath']) + if iterable(f): + f = list(f) + for i, fi in enumerate(f): + if isinstance(fi, Eq): + f[i] = fi.lhs - fi.rhs + f = Matrix(f).T + if iterable(x0): + x0 = list(x0) + if not isinstance(f, Matrix): + # assume it's a SymPy expression + if isinstance(f, Eq): + f = f.lhs - f.rhs + elif f.is_Relational: + raise TypeError('nsolve cannot accept inequalities') + syms = f.free_symbols + if fargs is None: + fargs = syms.copy().pop() + if not (len(syms) == 1 and (fargs in syms or fargs[0] in syms)): + raise ValueError(filldedent(''' + expected a one-dimensional and numerical function''')) + + # the function is much better behaved if there is no denominator + # but sending the numerator is left to the user since sometimes + # the function is better behaved when the denominator is present + # e.g., issue 11768 + + f = lambdify(fargs, f, modules) + x = sympify(findroot(f, x0, **kwargs)) + if as_dict: + return [{fargs: x}] + return x + + if len(fargs) > f.cols: + raise NotImplementedError(filldedent(''' + need at least as many equations as variables''')) + verbose = kwargs.get('verbose', False) + if verbose: + print('f(x):') + print(f) + # derive Jacobian + J = f.jacobian(fargs) + if verbose: + print('J(x):') + print(J) + # create functions + f = lambdify(fargs, f.T, modules) + J = lambdify(fargs, J, modules) + # solve the system numerically + x = findroot(f, x0, J=J, **kwargs) + if as_dict: + return [dict(zip(fargs, [sympify(xi) for xi in x]))] + return Matrix(x) + + +def _invert(eq, *symbols, **kwargs): + """ + Return tuple (i, d) where ``i`` is independent of *symbols* and ``d`` + contains symbols. + + Explanation + =========== + + ``i`` and ``d`` are obtained after recursively using algebraic inversion + until an uninvertible ``d`` remains. If there are no free symbols then + ``d`` will be zero. Some (but not necessarily all) solutions to the + expression ``i - d`` will be related to the solutions of the original + expression. + + Examples + ======== + + >>> from sympy.solvers.solvers import _invert as invert + >>> from sympy import sqrt, cos + >>> from sympy.abc import x, y + >>> invert(x - 3) + (3, x) + >>> invert(3) + (3, 0) + >>> invert(2*cos(x) - 1) + (1/2, cos(x)) + >>> invert(sqrt(x) - 3) + (3, sqrt(x)) + >>> invert(sqrt(x) + y, x) + (-y, sqrt(x)) + >>> invert(sqrt(x) + y, y) + (-sqrt(x), y) + >>> invert(sqrt(x) + y, x, y) + (0, sqrt(x) + y) + + If there is more than one symbol in a power's base and the exponent + is not an Integer, then the principal root will be used for the + inversion: + + >>> invert(sqrt(x + y) - 2) + (4, x + y) + >>> invert(sqrt(x + y) + 2) # note +2 instead of -2 + (4, x + y) + + If the exponent is an Integer, setting ``integer_power`` to True + will force the principal root to be selected: + + >>> invert(x**2 - 4, integer_power=True) + (2, x) + + """ + eq = sympify(eq) + if eq.args: + # make sure we are working with flat eq + eq = eq.func(*eq.args) + free = eq.free_symbols + if not symbols: + symbols = free + if not free & set(symbols): + return eq, S.Zero + + dointpow = bool(kwargs.get('integer_power', False)) + + lhs = eq + rhs = S.Zero + while True: + was = lhs + while True: + indep, dep = lhs.as_independent(*symbols) + + # dep + indep == rhs + if lhs.is_Add: + # this indicates we have done it all + if indep.is_zero: + break + + lhs = dep + rhs -= indep + + # dep * indep == rhs + else: + # this indicates we have done it all + if indep is S.One: + break + + lhs = dep + rhs /= indep + + # collect like-terms in symbols + if lhs.is_Add: + terms = {} + for a in lhs.args: + i, d = a.as_independent(*symbols) + terms.setdefault(d, []).append(i) + if any(len(v) > 1 for v in terms.values()): + args = [] + for d, i in terms.items(): + if len(i) > 1: + args.append(Add(*i)*d) + else: + args.append(i[0]*d) + lhs = Add(*args) + + # if it's a two-term Add with rhs = 0 and two powers we can get the + # dependent terms together, e.g. 3*f(x) + 2*g(x) -> f(x)/g(x) = -2/3 + if lhs.is_Add and not rhs and len(lhs.args) == 2 and \ + not lhs.is_polynomial(*symbols): + a, b = ordered(lhs.args) + ai, ad = a.as_independent(*symbols) + bi, bd = b.as_independent(*symbols) + if any(_ispow(i) for i in (ad, bd)): + a_base, a_exp = ad.as_base_exp() + b_base, b_exp = bd.as_base_exp() + if a_base == b_base and a_exp.extract_additively(b_exp) is None: + # a = -b and exponents do not have canceling terms/factors + # e.g. if exponents were 3*x and x then the ratio would have + # an exponent of 2*x: one of the roots would be lost + rat = powsimp(powdenest(ad/bd)) + lhs = rat + rhs = -bi/ai + else: + rat = ad/bd + _lhs = powsimp(ad/bd) + if _lhs != rat: + lhs = _lhs + rhs = -bi/ai + elif ai == -bi: + if isinstance(ad, Function) and ad.func == bd.func: + if len(ad.args) == len(bd.args) == 1: + lhs = ad.args[0] - bd.args[0] + elif len(ad.args) == len(bd.args): + # should be able to solve + # f(x, y) - f(2 - x, 0) == 0 -> x == 1 + raise NotImplementedError( + 'equal function with more than 1 argument') + else: + raise ValueError( + 'function with different numbers of args') + + elif lhs.is_Mul and any(_ispow(a) for a in lhs.args): + lhs = powsimp(powdenest(lhs)) + + if lhs.is_Function: + if hasattr(lhs, 'inverse') and lhs.inverse() is not None and len(lhs.args) == 1: + # -1 + # f(x) = g -> x = f (g) + # + # /!\ inverse should not be defined if there are multiple values + # for the function -- these are handled in _tsolve + # + rhs = lhs.inverse()(rhs) + lhs = lhs.args[0] + elif isinstance(lhs, atan2): + y, x = lhs.args + lhs = 2*atan(y/(sqrt(x**2 + y**2) + x)) + elif lhs.func == rhs.func: + if len(lhs.args) == len(rhs.args) == 1: + lhs = lhs.args[0] + rhs = rhs.args[0] + elif len(lhs.args) == len(rhs.args): + # should be able to solve + # f(x, y) == f(2, 3) -> x == 2 + # f(x, x + y) == f(2, 3) -> x == 2 + raise NotImplementedError( + 'equal function with more than 1 argument') + else: + raise ValueError( + 'function with different numbers of args') + + + if rhs and lhs.is_Pow and lhs.exp.is_Integer and lhs.exp < 0: + lhs = 1/lhs + rhs = 1/rhs + + # base**a = b -> base = b**(1/a) if + # a is an Integer and dointpow=True (this gives real branch of root) + # a is not an Integer and the equation is multivariate and the + # base has more than 1 symbol in it + # The rationale for this is that right now the multi-system solvers + # doesn't try to resolve generators to see, for example, if the whole + # system is written in terms of sqrt(x + y) so it will just fail, so we + # do that step here. + if lhs.is_Pow and ( + lhs.exp.is_Integer and dointpow or not lhs.exp.is_Integer and + len(symbols) > 1 and len(lhs.base.free_symbols & set(symbols)) > 1): + rhs = rhs**(1/lhs.exp) + lhs = lhs.base + + if lhs == was: + break + return rhs, lhs + + +def unrad(eq, *syms, **flags): + """ + Remove radicals with symbolic arguments and return (eq, cov), + None, or raise an error. + + Explanation + =========== + + None is returned if there are no radicals to remove. + + NotImplementedError is raised if there are radicals and they cannot be + removed or if the relationship between the original symbols and the + change of variable needed to rewrite the system as a polynomial cannot + be solved. + + Otherwise the tuple, ``(eq, cov)``, is returned where: + + *eq*, ``cov`` + *eq* is an equation without radicals (in the symbol(s) of + interest) whose solutions are a superset of the solutions to the + original expression. *eq* might be rewritten in terms of a new + variable; the relationship to the original variables is given by + ``cov`` which is a list containing ``v`` and ``v**p - b`` where + ``p`` is the power needed to clear the radical and ``b`` is the + radical now expressed as a polynomial in the symbols of interest. + For example, for sqrt(2 - x) the tuple would be + ``(c, c**2 - 2 + x)``. The solutions of *eq* will contain + solutions to the original equation (if there are any). + + *syms* + An iterable of symbols which, if provided, will limit the focus of + radical removal: only radicals with one or more of the symbols of + interest will be cleared. All free symbols are used if *syms* is not + set. + + *flags* are used internally for communication during recursive calls. + Two options are also recognized: + + ``take``, when defined, is interpreted as a single-argument function + that returns True if a given Pow should be handled. + + Radicals can be removed from an expression if: + + * All bases of the radicals are the same; a change of variables is + done in this case. + * If all radicals appear in one term of the expression. + * There are only four terms with sqrt() factors or there are less than + four terms having sqrt() factors. + * There are only two terms with radicals. + + Examples + ======== + + >>> from sympy.solvers.solvers import unrad + >>> from sympy.abc import x + >>> from sympy import sqrt, Rational, root + + >>> unrad(sqrt(x)*x**Rational(1, 3) + 2) + (x**5 - 64, []) + >>> unrad(sqrt(x) + root(x + 1, 3)) + (-x**3 + x**2 + 2*x + 1, []) + >>> eq = sqrt(x) + root(x, 3) - 2 + >>> unrad(eq) + (_p**3 + _p**2 - 2, [_p, _p**6 - x]) + + """ + + uflags = {"check": False, "simplify": False} + + def _cov(p, e): + if cov: + # XXX - uncovered + oldp, olde = cov + if Poly(e, p).degree(p) in (1, 2): + cov[:] = [p, olde.subs(oldp, _vsolve(e, p, **uflags)[0])] + else: + raise NotImplementedError + else: + cov[:] = [p, e] + + def _canonical(eq, cov): + if cov: + # change symbol to vanilla so no solutions are eliminated + p, e = cov + rep = {p: Dummy(p.name)} + eq = eq.xreplace(rep) + cov = [p.xreplace(rep), e.xreplace(rep)] + + # remove constants and powers of factors since these don't change + # the location of the root; XXX should factor or factor_terms be used? + eq = factor_terms(_mexpand(eq.as_numer_denom()[0], recursive=True), clear=True) + if eq.is_Mul: + args = [] + for f in eq.args: + if f.is_number: + continue + if f.is_Pow: + args.append(f.base) + else: + args.append(f) + eq = Mul(*args) # leave as Mul for more efficient solving + + # make the sign canonical + margs = list(Mul.make_args(eq)) + changed = False + for i, m in enumerate(margs): + if m.could_extract_minus_sign(): + margs[i] = -m + changed = True + if changed: + eq = Mul(*margs, evaluate=False) + + return eq, cov + + def _Q(pow): + # return leading Rational of denominator of Pow's exponent + c = pow.as_base_exp()[1].as_coeff_Mul()[0] + if not c.is_Rational: + return S.One + return c.q + + # define the _take method that will determine whether a term is of interest + def _take(d): + # return True if coefficient of any factor's exponent's den is not 1 + for pow in Mul.make_args(d): + if not pow.is_Pow: + continue + if _Q(pow) == 1: + continue + if pow.free_symbols & syms: + return True + return False + _take = flags.setdefault('_take', _take) + + if isinstance(eq, Eq): + eq = eq.lhs - eq.rhs # XXX legacy Eq as Eqn support + elif not isinstance(eq, Expr): + return + + cov, nwas, rpt = [flags.setdefault(k, v) for k, v in + sorted({"cov": [], "n": None, "rpt": 0}.items())] + + # preconditioning + eq = powdenest(factor_terms(eq, radical=True, clear=True)) + eq = eq.as_numer_denom()[0] + eq = _mexpand(eq, recursive=True) + if eq.is_number: + return + + # see if there are radicals in symbols of interest + syms = set(syms) or eq.free_symbols # _take uses this + poly = eq.as_poly() + gens = [g for g in poly.gens if _take(g)] + if not gens: + return + + # recast poly in terms of eigen-gens + poly = eq.as_poly(*gens) + + # not a polynomial e.g. 1 + sqrt(x)*exp(sqrt(x)) with gen sqrt(x) + if poly is None: + return + + # - an exponent has a symbol of interest (don't handle) + if any(g.exp.has(*syms) for g in gens): + return + + def _rads_bases_lcm(poly): + # if all the bases are the same or all the radicals are in one + # term, `lcm` will be the lcm of the denominators of the + # exponents of the radicals + lcm = 1 + rads = set() + bases = set() + for g in poly.gens: + q = _Q(g) + if q != 1: + rads.add(g) + lcm = ilcm(lcm, q) + bases.add(g.base) + return rads, bases, lcm + rads, bases, lcm = _rads_bases_lcm(poly) + + covsym = Dummy('p', nonnegative=True) + + # only keep in syms symbols that actually appear in radicals; + # and update gens + newsyms = set() + for r in rads: + newsyms.update(syms & r.free_symbols) + if newsyms != syms: + syms = newsyms + # get terms together that have common generators + drad = dict(zip(rads, range(len(rads)))) + rterms = {(): []} + args = Add.make_args(poly.as_expr()) + for t in args: + if _take(t): + common = set(t.as_poly().gens).intersection(rads) + key = tuple(sorted([drad[i] for i in common])) + else: + key = () + rterms.setdefault(key, []).append(t) + others = Add(*rterms.pop(())) + rterms = [Add(*rterms[k]) for k in rterms.keys()] + + # the output will depend on the order terms are processed, so + # make it canonical quickly + rterms = list(reversed(list(ordered(rterms)))) + + ok = False # we don't have a solution yet + depth = sqrt_depth(eq) + + if len(rterms) == 1 and not (rterms[0].is_Add and lcm > 2): + eq = rterms[0]**lcm - ((-others)**lcm) + ok = True + else: + if len(rterms) == 1 and rterms[0].is_Add: + rterms = list(rterms[0].args) + if len(bases) == 1: + b = bases.pop() + if len(syms) > 1: + x = b.free_symbols + else: + x = syms + x = list(ordered(x))[0] + try: + inv = _vsolve(covsym**lcm - b, x, **uflags) + if not inv: + raise NotImplementedError + eq = poly.as_expr().subs(b, covsym**lcm).subs(x, inv[0]) + _cov(covsym, covsym**lcm - b) + return _canonical(eq, cov) + except NotImplementedError: + pass + + if len(rterms) == 2: + if not others: + eq = rterms[0]**lcm - (-rterms[1])**lcm + ok = True + elif not log(lcm, 2).is_Integer: + # the lcm-is-power-of-two case is handled below + r0, r1 = rterms + if flags.get('_reverse', False): + r1, r0 = r0, r1 + i0 = _rads0, _bases0, lcm0 = _rads_bases_lcm(r0.as_poly()) + i1 = _rads1, _bases1, lcm1 = _rads_bases_lcm(r1.as_poly()) + for reverse in range(2): + if reverse: + i0, i1 = i1, i0 + r0, r1 = r1, r0 + _rads1, _, lcm1 = i1 + _rads1 = Mul(*_rads1) + t1 = _rads1**lcm1 + c = covsym**lcm1 - t1 + for x in syms: + try: + sol = _vsolve(c, x, **uflags) + if not sol: + raise NotImplementedError + neweq = r0.subs(x, sol[0]) + covsym*r1/_rads1 + \ + others + tmp = unrad(neweq, covsym) + if tmp: + eq, newcov = tmp + if newcov: + newp, newc = newcov + _cov(newp, c.subs(covsym, + _vsolve(newc, covsym, **uflags)[0])) + else: + _cov(covsym, c) + else: + eq = neweq + _cov(covsym, c) + ok = True + break + except NotImplementedError: + if reverse: + raise NotImplementedError( + 'no successful change of variable found') + else: + pass + if ok: + break + elif len(rterms) == 3: + # two cube roots and another with order less than 5 + # (so an analytical solution can be found) or a base + # that matches one of the cube root bases + info = [_rads_bases_lcm(i.as_poly()) for i in rterms] + RAD = 0 + BASES = 1 + LCM = 2 + if info[0][LCM] != 3: + info.append(info.pop(0)) + rterms.append(rterms.pop(0)) + elif info[1][LCM] != 3: + info.append(info.pop(1)) + rterms.append(rterms.pop(1)) + if info[0][LCM] == info[1][LCM] == 3: + if info[1][BASES] != info[2][BASES]: + info[0], info[1] = info[1], info[0] + rterms[0], rterms[1] = rterms[1], rterms[0] + if info[1][BASES] == info[2][BASES]: + eq = rterms[0]**3 + (rterms[1] + rterms[2] + others)**3 + ok = True + elif info[2][LCM] < 5: + # a*root(A, 3) + b*root(B, 3) + others = c + a, b, c, d, A, B = [Dummy(i) for i in 'abcdAB'] + # zz represents the unraded expression into which the + # specifics for this case are substituted + zz = (c - d)*(A**3*a**9 + 3*A**2*B*a**6*b**3 - + 3*A**2*a**6*c**3 + 9*A**2*a**6*c**2*d - 9*A**2*a**6*c*d**2 + + 3*A**2*a**6*d**3 + 3*A*B**2*a**3*b**6 + 21*A*B*a**3*b**3*c**3 - + 63*A*B*a**3*b**3*c**2*d + 63*A*B*a**3*b**3*c*d**2 - + 21*A*B*a**3*b**3*d**3 + 3*A*a**3*c**6 - 18*A*a**3*c**5*d + + 45*A*a**3*c**4*d**2 - 60*A*a**3*c**3*d**3 + 45*A*a**3*c**2*d**4 - + 18*A*a**3*c*d**5 + 3*A*a**3*d**6 + B**3*b**9 - 3*B**2*b**6*c**3 + + 9*B**2*b**6*c**2*d - 9*B**2*b**6*c*d**2 + 3*B**2*b**6*d**3 + + 3*B*b**3*c**6 - 18*B*b**3*c**5*d + 45*B*b**3*c**4*d**2 - + 60*B*b**3*c**3*d**3 + 45*B*b**3*c**2*d**4 - 18*B*b**3*c*d**5 + + 3*B*b**3*d**6 - c**9 + 9*c**8*d - 36*c**7*d**2 + 84*c**6*d**3 - + 126*c**5*d**4 + 126*c**4*d**5 - 84*c**3*d**6 + 36*c**2*d**7 - + 9*c*d**8 + d**9) + def _t(i): + b = Mul(*info[i][RAD]) + return cancel(rterms[i]/b), Mul(*info[i][BASES]) + aa, AA = _t(0) + bb, BB = _t(1) + cc = -rterms[2] + dd = others + eq = zz.xreplace(dict(zip( + (a, A, b, B, c, d), + (aa, AA, bb, BB, cc, dd)))) + ok = True + # handle power-of-2 cases + if not ok: + if log(lcm, 2).is_Integer and (not others and + len(rterms) == 4 or len(rterms) < 4): + def _norm2(a, b): + return a**2 + b**2 + 2*a*b + + if len(rterms) == 4: + # (r0+r1)**2 - (r2+r3)**2 + r0, r1, r2, r3 = rterms + eq = _norm2(r0, r1) - _norm2(r2, r3) + ok = True + elif len(rterms) == 3: + # (r1+r2)**2 - (r0+others)**2 + r0, r1, r2 = rterms + eq = _norm2(r1, r2) - _norm2(r0, others) + ok = True + elif len(rterms) == 2: + # r0**2 - (r1+others)**2 + r0, r1 = rterms + eq = r0**2 - _norm2(r1, others) + ok = True + + new_depth = sqrt_depth(eq) if ok else depth + rpt += 1 # XXX how many repeats with others unchanging is enough? + if not ok or ( + nwas is not None and len(rterms) == nwas and + new_depth is not None and new_depth == depth and + rpt > 3): + raise NotImplementedError('Cannot remove all radicals') + + flags.update({"cov": cov, "n": len(rterms), "rpt": rpt}) + neq = unrad(eq, *syms, **flags) + if neq: + eq, cov = neq + eq, cov = _canonical(eq, cov) + return eq, cov + + +# delayed imports +from sympy.solvers.bivariate import ( + bivariate_type, _solve_lambert, _filtered_gens) diff --git a/.venv/lib/python3.13/site-packages/sympy/solvers/solveset.py b/.venv/lib/python3.13/site-packages/sympy/solvers/solveset.py new file mode 100644 index 0000000000000000000000000000000000000000..0ae242d9c8c4c0d1c1c46cd968a0c5e547ff0f66 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/solvers/solveset.py @@ -0,0 +1,4131 @@ +""" +This module contains functions to: + + - solve a single equation for a single variable, in any domain either real or complex. + + - solve a single transcendental equation for a single variable in any domain either real or complex. + (currently supports solving in real domain only) + + - solve a system of linear equations with N variables and M equations. + + - solve a system of Non Linear Equations with N variables and M equations +""" +from sympy.core.sympify import sympify +from sympy.core import (S, Pow, Dummy, pi, Expr, Wild, Mul, + Add, Basic) +from sympy.core.containers import Tuple +from sympy.core.function import (Lambda, expand_complex, AppliedUndef, + expand_log, _mexpand, expand_trig, nfloat) +from sympy.core.mod import Mod +from sympy.core.numbers import I, Number, Rational, oo +from sympy.core.intfunc import integer_log +from sympy.core.relational import Eq, Ne, Relational +from sympy.core.sorting import default_sort_key, ordered +from sympy.core.symbol import Symbol, _uniquely_named_symbol +from sympy.core.sympify import _sympify +from sympy.core.traversal import preorder_traversal +from sympy.external.gmpy import gcd as number_gcd, lcm as number_lcm +from sympy.polys.matrices.linsolve import _linear_eq_to_dict +from sympy.polys.polyroots import UnsolvableFactorError +from sympy.simplify.simplify import simplify, fraction, trigsimp, nsimplify +from sympy.simplify import powdenest, logcombine +from sympy.functions import (log, tan, cot, sin, cos, sec, csc, exp, + acos, asin, atan, acot, acsc, asec, + piecewise_fold, Piecewise) +from sympy.functions.combinatorial.numbers import totient +from sympy.functions.elementary.complexes import Abs, arg, re, im +from sympy.functions.elementary.hyperbolic import (HyperbolicFunction, + sinh, cosh, tanh, coth, sech, csch, + asinh, acosh, atanh, acoth, asech, acsch) +from sympy.functions.elementary.miscellaneous import real_root +from sympy.functions.elementary.trigonometric import TrigonometricFunction +from sympy.logic.boolalg import And, BooleanTrue +from sympy.sets import (FiniteSet, imageset, Interval, Intersection, + Union, ConditionSet, ImageSet, Complement, Contains) +from sympy.sets.sets import Set, ProductSet +from sympy.matrices import zeros, Matrix, MatrixBase +from sympy.ntheory.factor_ import divisors +from sympy.ntheory.residue_ntheory import discrete_log, nthroot_mod +from sympy.polys import (roots, Poly, degree, together, PolynomialError, + RootOf, factor, lcm, gcd) +from sympy.polys.polyerrors import CoercionFailed +from sympy.polys.polytools import invert, groebner, poly +from sympy.polys.solvers import (sympy_eqs_to_ring, solve_lin_sys, + PolyNonlinearError) +from sympy.polys.matrices.linsolve import _linsolve +from sympy.solvers.solvers import (checksol, denoms, unrad, + _simple_dens, recast_to_symbols) +from sympy.solvers.polysys import solve_poly_system +from sympy.utilities import filldedent +from sympy.utilities.iterables import (numbered_symbols, has_dups, + is_sequence, iterable) +from sympy.calculus.util import periodicity, continuous_domain, function_range + + +from types import GeneratorType + + +class NonlinearError(ValueError): + """Raised when unexpectedly encountering nonlinear equations""" + pass + + +def _masked(f, *atoms): + """Return ``f``, with all objects given by ``atoms`` replaced with + Dummy symbols, ``d``, and the list of replacements, ``(d, e)``, + where ``e`` is an object of type given by ``atoms`` in which + any other instances of atoms have been recursively replaced with + Dummy symbols, too. The tuples are ordered so that if they are + applied in sequence, the origin ``f`` will be restored. + + Examples + ======== + + >>> from sympy import cos + >>> from sympy.abc import x + >>> from sympy.solvers.solveset import _masked + + >>> f = cos(cos(x) + 1) + >>> f, reps = _masked(cos(1 + cos(x)), cos) + >>> f + _a1 + >>> reps + [(_a1, cos(_a0 + 1)), (_a0, cos(x))] + >>> for d, e in reps: + ... f = f.xreplace({d: e}) + >>> f + cos(cos(x) + 1) + """ + sym = numbered_symbols('a', cls=Dummy, real=True) + mask = [] + for a in ordered(f.atoms(*atoms)): + for i in mask: + a = a.replace(*i) + mask.append((a, next(sym))) + for i, (o, n) in enumerate(mask): + f = f.replace(o, n) + mask[i] = (n, o) + mask = list(reversed(mask)) + return f, mask + + +def _invert(f_x, y, x, domain=S.Complexes): + r""" + Reduce the complex valued equation $f(x) = y$ to a set of equations + + $$\left\{g(x) = h_1(y),\ g(x) = h_2(y),\ \dots,\ g(x) = h_n(y) \right\}$$ + + where $g(x)$ is a simpler function than $f(x)$. The return value is a tuple + $(g(x), \mathrm{set}_h)$, where $g(x)$ is a function of $x$ and $\mathrm{set}_h$ is + the set of function $\left\{h_1(y), h_2(y), \dots, h_n(y)\right\}$. + Here, $y$ is not necessarily a symbol. + + $\mathrm{set}_h$ contains the functions, along with the information + about the domain in which they are valid, through set + operations. For instance, if :math:`y = |x| - n` is inverted + in the real domain, then $\mathrm{set}_h$ is not simply + $\{-n, n\}$ as the nature of `n` is unknown; rather, it is: + + $$ \left(\left[0, \infty\right) \cap \left\{n\right\}\right) \cup + \left(\left(-\infty, 0\right] \cap \left\{- n\right\}\right)$$ + + By default, the complex domain is used which means that inverting even + seemingly simple functions like $\exp(x)$ will give very different + results from those obtained in the real domain. + (In the case of $\exp(x)$, the inversion via $\log$ is multi-valued + in the complex domain, having infinitely many branches.) + + If you are working with real values only (or you are not sure which + function to use) you should probably set the domain to + ``S.Reals`` (or use ``invert_real`` which does that automatically). + + + Examples + ======== + + >>> from sympy.solvers.solveset import invert_complex, invert_real + >>> from sympy.abc import x, y + >>> from sympy import exp + + When does exp(x) == y? + + >>> invert_complex(exp(x), y, x) + (x, ImageSet(Lambda(_n, I*(2*_n*pi + arg(y)) + log(Abs(y))), Integers)) + >>> invert_real(exp(x), y, x) + (x, Intersection({log(y)}, Reals)) + + When does exp(x) == 1? + + >>> invert_complex(exp(x), 1, x) + (x, ImageSet(Lambda(_n, 2*_n*I*pi), Integers)) + >>> invert_real(exp(x), 1, x) + (x, {0}) + + See Also + ======== + invert_real, invert_complex + """ + x = sympify(x) + if not x.is_Symbol: + raise ValueError("x must be a symbol") + f_x = sympify(f_x) + if x not in f_x.free_symbols: + raise ValueError("Inverse of constant function doesn't exist") + y = sympify(y) + if x in y.free_symbols: + raise ValueError("y should be independent of x ") + + if domain.is_subset(S.Reals): + x1, s = _invert_real(f_x, FiniteSet(y), x) + else: + x1, s = _invert_complex(f_x, FiniteSet(y), x) + + # f couldn't be inverted completely; return unmodified. + if x1 != x: + return x1, s + + # Avoid adding gratuitous intersections with S.Complexes. Actual + # conditions should be handled by the respective inverters. + if domain is S.Complexes: + return x1, s + + if isinstance(s, FiniteSet): + return x1, s.intersect(domain) + + # "Fancier" solution sets like those obtained by inversion of trigonometric + # functions already include general validity conditions (i.e. conditions on + # the domain of the respective inverse functions), so we should avoid adding + # blanket intersections with S.Reals. But subsets of R (or C) must still be + # accounted for. + if domain is S.Reals: + return x1, s + else: + return x1, s.intersect(domain) + + +invert_complex = _invert + + +def invert_real(f_x, y, x): + """ + Inverts a real-valued function. Same as :func:`invert_complex`, but sets + the domain to ``S.Reals`` before inverting. + """ + return _invert(f_x, y, x, S.Reals) + + +def _invert_real(f, g_ys, symbol): + """Helper function for _invert.""" + + if f == symbol or g_ys is S.EmptySet: + return (symbol, g_ys) + + n = Dummy('n', real=True) + + if isinstance(f, exp) or (f.is_Pow and f.base == S.Exp1): + return _invert_real(f.exp, + imageset(Lambda(n, log(n)), g_ys), + symbol) + + if hasattr(f, 'inverse') and f.inverse() is not None and not isinstance(f, ( + TrigonometricFunction, + HyperbolicFunction, + )): + if len(f.args) > 1: + raise ValueError("Only functions with one argument are supported.") + return _invert_real(f.args[0], + imageset(Lambda(n, f.inverse()(n)), g_ys), + symbol) + + if isinstance(f, Abs): + return _invert_abs(f.args[0], g_ys, symbol) + + if f.is_Add: + # f = g + h + g, h = f.as_independent(symbol) + if g is not S.Zero: + return _invert_real(h, imageset(Lambda(n, n - g), g_ys), symbol) + + if f.is_Mul: + # f = g*h + g, h = f.as_independent(symbol) + + if g is not S.One: + return _invert_real(h, imageset(Lambda(n, n/g), g_ys), symbol) + + if f.is_Pow: + base, expo = f.args + base_has_sym = base.has(symbol) + expo_has_sym = expo.has(symbol) + + if not expo_has_sym: + + if expo.is_rational: + num, den = expo.as_numer_denom() + + if den % 2 == 0 and num % 2 == 1 and den.is_zero is False: + # Here we have f(x)**(num/den) = y + # where den is nonzero and even and y is an element + # of the set g_ys. + # den is even, so we are only interested in the cases + # where both f(x) and y are positive. + # Restricting y to be positive (using the set g_ys_pos) + # means that y**(den/num) is always positive. + # Therefore it isn't necessary to also constrain f(x) + # to be positive because we are only going to + # find solutions of f(x) = y**(d/n) + # where the rhs is already required to be positive. + root = Lambda(n, real_root(n, expo)) + g_ys_pos = g_ys & Interval(0, oo) + res = imageset(root, g_ys_pos) + _inv, _set = _invert_real(base, res, symbol) + return (_inv, _set) + + if den % 2 == 1: + root = Lambda(n, real_root(n, expo)) + res = imageset(root, g_ys) + if num % 2 == 0: + neg_res = imageset(Lambda(n, -n), res) + return _invert_real(base, res + neg_res, symbol) + if num % 2 == 1: + return _invert_real(base, res, symbol) + + elif expo.is_irrational: + root = Lambda(n, real_root(n, expo)) + g_ys_pos = g_ys & Interval(0, oo) + res = imageset(root, g_ys_pos) + return _invert_real(base, res, symbol) + + else: + # indeterminate exponent, e.g. Float or parity of + # num, den of rational could not be determined + pass # use default return + + if not base_has_sym: + rhs = g_ys.args[0] + if base.is_positive: + return _invert_real(expo, + imageset(Lambda(n, log(n, base, evaluate=False)), g_ys), symbol) + elif base.is_negative: + s, b = integer_log(rhs, base) + if b: + return _invert_real(expo, FiniteSet(s), symbol) + else: + return (expo, S.EmptySet) + elif base.is_zero: + one = Eq(rhs, 1) + if one == S.true: + # special case: 0**x - 1 + return _invert_real(expo, FiniteSet(0), symbol) + elif one == S.false: + return (expo, S.EmptySet) + + if isinstance(f, (TrigonometricFunction, HyperbolicFunction)): + return _invert_trig_hyp_real(f, g_ys, symbol) + + return (f, g_ys) + + +# Dictionaries of inverses will be cached after first use. +_trig_inverses = None +_hyp_inverses = None + +def _invert_trig_hyp_real(f, g_ys, symbol): + """Helper function for inverting trigonometric and hyperbolic functions. + + This helper only handles inversion over the reals. + + For trigonometric functions only finite `g_ys` sets are implemented. + + For hyperbolic functions the set `g_ys` is checked against the domain of the + respective inverse functions. Infinite `g_ys` sets are also supported. + """ + + if isinstance(f, HyperbolicFunction): + n = Dummy('n', real=True) + + if isinstance(f, sinh): + # asinh is defined over R. + return _invert_real(f.args[0], imageset(n, asinh(n), g_ys), symbol) + + if isinstance(f, cosh): + g_ys_dom = g_ys.intersect(Interval(1, oo)) + if isinstance(g_ys_dom, Intersection): + # could not properly resolve domain check + if isinstance(g_ys, FiniteSet): + # If g_ys is a `FiniteSet`` it should be sufficient to just + # let the calling `_invert_real()` add an intersection with + # `S.Reals` (or a subset `domain`) to ensure that only valid + # (real) solutions are returned. + # This avoids adding "too many" Intersections or + # ConditionSets in the returned set. + g_ys_dom = g_ys + else: + return (f, g_ys) + return _invert_real(f.args[0], Union( + imageset(n, acosh(n), g_ys_dom), + imageset(n, -acosh(n), g_ys_dom)), symbol) + + if isinstance(f, sech): + g_ys_dom = g_ys.intersect(Interval.Lopen(0, 1)) + if isinstance(g_ys_dom, Intersection): + if isinstance(g_ys, FiniteSet): + g_ys_dom = g_ys + else: + return (f, g_ys) + return _invert_real(f.args[0], Union( + imageset(n, asech(n), g_ys_dom), + imageset(n, -asech(n), g_ys_dom)), symbol) + + if isinstance(f, tanh): + g_ys_dom = g_ys.intersect(Interval.open(-1, 1)) + if isinstance(g_ys_dom, Intersection): + if isinstance(g_ys, FiniteSet): + g_ys_dom = g_ys + else: + return (f, g_ys) + return _invert_real(f.args[0], + imageset(n, atanh(n), g_ys_dom), symbol) + + if isinstance(f, coth): + g_ys_dom = g_ys - Interval(-1, 1) + if isinstance(g_ys_dom, Complement): + if isinstance(g_ys, FiniteSet): + g_ys_dom = g_ys + else: + return (f, g_ys) + return _invert_real(f.args[0], + imageset(n, acoth(n), g_ys_dom), symbol) + + if isinstance(f, csch): + g_ys_dom = g_ys - FiniteSet(0) + if isinstance(g_ys_dom, Complement): + if isinstance(g_ys, FiniteSet): + g_ys_dom = g_ys + else: + return (f, g_ys) + return _invert_real(f.args[0], + imageset(n, acsch(n), g_ys_dom), symbol) + + elif isinstance(f, TrigonometricFunction) and isinstance(g_ys, FiniteSet): + def _get_trig_inverses(func): + global _trig_inverses + if _trig_inverses is None: + _trig_inverses = { + sin : ((asin, lambda y: pi-asin(y)), 2*pi, Interval(-1, 1)), + cos : ((acos, lambda y: -acos(y)), 2*pi, Interval(-1, 1)), + tan : ((atan,), pi, S.Reals), + cot : ((acot,), pi, S.Reals), + sec : ((asec, lambda y: -asec(y)), 2*pi, + Union(Interval(-oo, -1), Interval(1, oo))), + csc : ((acsc, lambda y: pi-acsc(y)), 2*pi, + Union(Interval(-oo, -1), Interval(1, oo)))} + return _trig_inverses[func] + + invs, period, rng = _get_trig_inverses(f.func) + n = Dummy('n', integer=True) + def create_return_set(g): + # returns ConditionSet that will be part of the final (x, set) tuple + invsimg = Union(*[ + imageset(n, period*n + inv(g), S.Integers) for inv in invs]) + inv_f, inv_g_ys = _invert_real(f.args[0], invsimg, symbol) + if inv_f == symbol: # inversion successful + conds = rng.contains(g) + return ConditionSet(symbol, conds, inv_g_ys) + else: + return ConditionSet(symbol, Eq(f, g), S.Reals) + + retset = Union(*[create_return_set(g) for g in g_ys]) + return (symbol, retset) + + else: + return (f, g_ys) + + +def _invert_trig_hyp_complex(f, g_ys, symbol): + """Helper function for inverting trigonometric and hyperbolic functions. + + This helper only handles inversion over the complex numbers. + Only finite `g_ys` sets are implemented. + + Handling of singularities is only implemented for hyperbolic equations. + In case of a symbolic element g in g_ys a ConditionSet may be returned. + """ + + if isinstance(f, TrigonometricFunction) and isinstance(g_ys, FiniteSet): + def inv(trig): + if isinstance(trig, (sin, csc)): + F = asin if isinstance(trig, sin) else acsc + return ( + lambda a: 2*n*pi + F(a), + lambda a: 2*n*pi + pi - F(a)) + if isinstance(trig, (cos, sec)): + F = acos if isinstance(trig, cos) else asec + return ( + lambda a: 2*n*pi + F(a), + lambda a: 2*n*pi - F(a)) + if isinstance(trig, (tan, cot)): + return (lambda a: n*pi + trig.inverse()(a),) + + n = Dummy('n', integer=True) + invs = S.EmptySet + for L in inv(f): + invs += Union(*[imageset(Lambda(n, L(g)), S.Integers) for g in g_ys]) + return _invert_complex(f.args[0], invs, symbol) + + elif isinstance(f, HyperbolicFunction) and isinstance(g_ys, FiniteSet): + # There are two main options regarding singularities / domain checking + # for symbolic elements in g_ys: + # 1. Add a "catch-all" intersection with S.Complexes. + # 2. ConditionSets. + # At present ConditionSets seem to work better and have the additional + # benefit of representing the precise conditions that must be satisfied. + # The conditions are also rather straightforward. (At most two isolated + # points.) + def _get_hyp_inverses(func): + global _hyp_inverses + if _hyp_inverses is None: + _hyp_inverses = { + sinh : ((asinh, lambda y: I*pi-asinh(y)), 2*I*pi, ()), + cosh : ((acosh, lambda y: -acosh(y)), 2*I*pi, ()), + tanh : ((atanh,), I*pi, (-1, 1)), + coth : ((acoth,), I*pi, (-1, 1)), + sech : ((asech, lambda y: -asech(y)), 2*I*pi, (0, )), + csch : ((acsch, lambda y: I*pi-acsch(y)), 2*I*pi, (0, ))} + return _hyp_inverses[func] + + # invs: iterable of main inverses, e.g. (acosh, -acosh). + # excl: iterable of singularities to be checked for. + invs, period, excl = _get_hyp_inverses(f.func) + n = Dummy('n', integer=True) + def create_return_set(g): + # returns ConditionSet that will be part of the final (x, set) tuple + invsimg = Union(*[ + imageset(n, period*n + inv(g), S.Integers) for inv in invs]) + inv_f, inv_g_ys = _invert_complex(f.args[0], invsimg, symbol) + if inv_f == symbol: # inversion successful + conds = And(*[Ne(g, e) for e in excl]) + return ConditionSet(symbol, conds, inv_g_ys) + else: + return ConditionSet(symbol, Eq(f, g), S.Complexes) + + retset = Union(*[create_return_set(g) for g in g_ys]) + return (symbol, retset) + + else: + return (f, g_ys) + + +def _invert_complex(f, g_ys, symbol): + """Helper function for _invert.""" + + if f == symbol or g_ys is S.EmptySet: + return (symbol, g_ys) + + n = Dummy('n') + + if f.is_Add: + # f = g + h + g, h = f.as_independent(symbol) + if g is not S.Zero: + return _invert_complex(h, imageset(Lambda(n, n - g), g_ys), symbol) + + if f.is_Mul: + # f = g*h + g, h = f.as_independent(symbol) + + if g is not S.One: + if g in {S.NegativeInfinity, S.ComplexInfinity, S.Infinity}: + return (h, S.EmptySet) + return _invert_complex(h, imageset(Lambda(n, n/g), g_ys), symbol) + + if f.is_Pow: + base, expo = f.args + # special case: g**r = 0 + # Could be improved like `_invert_real` to handle more general cases. + if expo.is_Rational and g_ys == FiniteSet(0): + if expo.is_positive: + return _invert_complex(base, g_ys, symbol) + + if hasattr(f, 'inverse') and f.inverse() is not None and \ + not isinstance(f, TrigonometricFunction) and \ + not isinstance(f, HyperbolicFunction) and \ + not isinstance(f, exp): + if len(f.args) > 1: + raise ValueError("Only functions with one argument are supported.") + return _invert_complex(f.args[0], + imageset(Lambda(n, f.inverse()(n)), g_ys), symbol) + + if isinstance(f, exp) or (f.is_Pow and f.base == S.Exp1): + if isinstance(g_ys, ImageSet): + # can solve up to `(d*exp(exp(...(exp(a*x + b))...) + c)` format. + # Further can be improved to `(d*exp(exp(...(exp(a*x**n + b*x**(n-1) + ... + f))...) + c)`. + g_ys_expr = g_ys.lamda.expr + g_ys_vars = g_ys.lamda.variables + k = Dummy('k{}'.format(len(g_ys_vars))) + g_ys_vars_1 = (k,) + g_ys_vars + exp_invs = Union(*[imageset(Lambda((g_ys_vars_1,), (I*(2*k*pi + arg(g_ys_expr)) + + log(Abs(g_ys_expr)))), S.Integers**(len(g_ys_vars_1)))]) + return _invert_complex(f.exp, exp_invs, symbol) + + elif isinstance(g_ys, FiniteSet): + exp_invs = Union(*[imageset(Lambda(n, I*(2*n*pi + arg(g_y)) + + log(Abs(g_y))), S.Integers) + for g_y in g_ys if g_y != 0]) + return _invert_complex(f.exp, exp_invs, symbol) + + if isinstance(f, (TrigonometricFunction, HyperbolicFunction)): + return _invert_trig_hyp_complex(f, g_ys, symbol) + + return (f, g_ys) + + +def _invert_abs(f, g_ys, symbol): + """Helper function for inverting absolute value functions. + + Returns the complete result of inverting an absolute value + function along with the conditions which must also be satisfied. + + If it is certain that all these conditions are met, a :class:`~.FiniteSet` + of all possible solutions is returned. If any condition cannot be + satisfied, an :class:`~.EmptySet` is returned. Otherwise, a + :class:`~.ConditionSet` of the solutions, with all the required conditions + specified, is returned. + + """ + if not g_ys.is_FiniteSet: + # this could be used for FiniteSet, but the + # results are more compact if they aren't, e.g. + # ConditionSet(x, Contains(n, Interval(0, oo)), {-n, n}) vs + # Union(Intersection(Interval(0, oo), {n}), Intersection(Interval(-oo, 0), {-n})) + # for the solution of abs(x) - n + pos = Intersection(g_ys, Interval(0, S.Infinity)) + parg = _invert_real(f, pos, symbol) + narg = _invert_real(-f, pos, symbol) + if parg[0] != narg[0]: + raise NotImplementedError + return parg[0], Union(narg[1], parg[1]) + + # check conditions: all these must be true. If any are unknown + # then return them as conditions which must be satisfied + unknown = [] + for a in g_ys.args: + ok = a.is_nonnegative if a.is_Number else a.is_positive + if ok is None: + unknown.append(a) + elif not ok: + return symbol, S.EmptySet + if unknown: + conditions = And(*[Contains(i, Interval(0, oo)) + for i in unknown]) + else: + conditions = True + n = Dummy('n', real=True) + # this is slightly different than above: instead of solving + # +/-f on positive values, here we solve for f on +/- g_ys + g_x, values = _invert_real(f, Union( + imageset(Lambda(n, n), g_ys), + imageset(Lambda(n, -n), g_ys)), symbol) + return g_x, ConditionSet(g_x, conditions, values) + + +def domain_check(f, symbol, p): + """Returns False if point p is infinite or any subexpression of f + is infinite or becomes so after replacing symbol with p. If none of + these conditions is met then True will be returned. + + Examples + ======== + + >>> from sympy import Mul, oo + >>> from sympy.abc import x + >>> from sympy.solvers.solveset import domain_check + >>> g = 1/(1 + (1/(x + 1))**2) + >>> domain_check(g, x, -1) + False + >>> domain_check(x**2, x, 0) + True + >>> domain_check(1/x, x, oo) + False + + * The function relies on the assumption that the original form + of the equation has not been changed by automatic simplification. + + >>> domain_check(x/x, x, 0) # x/x is automatically simplified to 1 + True + + * To deal with automatic evaluations use evaluate=False: + + >>> domain_check(Mul(x, 1/x, evaluate=False), x, 0) + False + """ + f, p = sympify(f), sympify(p) + if p.is_infinite: + return False + return _domain_check(f, symbol, p) + + +def _domain_check(f, symbol, p): + # helper for domain check + if f.is_Atom and f.is_finite: + return True + elif f.subs(symbol, p).is_infinite: + return False + elif isinstance(f, Piecewise): + # Check the cases of the Piecewise in turn. There might be invalid + # expressions in later cases that don't apply e.g. + # solveset(Piecewise((0, Eq(x, 0)), (1/x, True)), x) + for expr, cond in f.args: + condsubs = cond.subs(symbol, p) + if condsubs is S.false: + continue + elif condsubs is S.true: + return _domain_check(expr, symbol, p) + else: + # We don't know which case of the Piecewise holds. On this + # basis we cannot decide whether any solution is in or out of + # the domain. Ideally this function would allow returning a + # symbolic condition for the validity of the solution that + # could be handled in the calling code. In the mean time we'll + # give this particular solution the benefit of the doubt and + # let it pass. + return True + else: + # TODO : We should not blindly recurse through all args of arbitrary expressions like this + return all(_domain_check(g, symbol, p) + for g in f.args) + + +def _is_finite_with_finite_vars(f, domain=S.Complexes): + """ + Return True if the given expression is finite. For symbols that + do not assign a value for `complex` and/or `real`, the domain will + be used to assign a value; symbols that do not assign a value + for `finite` will be made finite. All other assumptions are + left unmodified. + """ + def assumptions(s): + A = s.assumptions0 + A.setdefault('finite', A.get('finite', True)) + if domain.is_subset(S.Reals): + # if this gets set it will make complex=True, too + A.setdefault('real', True) + else: + # don't change 'real' because being complex implies + # nothing about being real + A.setdefault('complex', True) + return A + + reps = {s: Dummy(**assumptions(s)) for s in f.free_symbols} + return f.xreplace(reps).is_finite + + +def _is_function_class_equation(func_class, f, symbol): + """ Tests whether the equation is an equation of the given function class. + + The given equation belongs to the given function class if it is + comprised of functions of the function class which are multiplied by + or added to expressions independent of the symbol. In addition, the + arguments of all such functions must be linear in the symbol as well. + + Examples + ======== + + >>> from sympy.solvers.solveset import _is_function_class_equation + >>> from sympy import tan, sin, tanh, sinh, exp + >>> from sympy.abc import x + >>> from sympy.functions.elementary.trigonometric import TrigonometricFunction + >>> from sympy.functions.elementary.hyperbolic import HyperbolicFunction + >>> _is_function_class_equation(TrigonometricFunction, exp(x) + tan(x), x) + False + >>> _is_function_class_equation(TrigonometricFunction, tan(x) + sin(x), x) + True + >>> _is_function_class_equation(TrigonometricFunction, tan(x**2), x) + False + >>> _is_function_class_equation(TrigonometricFunction, tan(x + 2), x) + True + >>> _is_function_class_equation(HyperbolicFunction, tanh(x) + sinh(x), x) + True + """ + if f.is_Mul or f.is_Add: + return all(_is_function_class_equation(func_class, arg, symbol) + for arg in f.args) + + if f.is_Pow: + if not f.exp.has(symbol): + return _is_function_class_equation(func_class, f.base, symbol) + else: + return False + + if not f.has(symbol): + return True + + if isinstance(f, func_class): + try: + g = Poly(f.args[0], symbol) + return g.degree() <= 1 + except PolynomialError: + return False + else: + return False + + +def _solve_as_rational(f, symbol, domain): + """ solve rational functions""" + f = together(_mexpand(f, recursive=True), deep=True) + g, h = fraction(f) + if not h.has(symbol): + try: + return _solve_as_poly(g, symbol, domain) + except NotImplementedError: + # The polynomial formed from g could end up having + # coefficients in a ring over which finding roots + # isn't implemented yet, e.g. ZZ[a] for some symbol a + return ConditionSet(symbol, Eq(f, 0), domain) + except CoercionFailed: + # contained oo, zoo or nan + return S.EmptySet + else: + valid_solns = _solveset(g, symbol, domain) + invalid_solns = _solveset(h, symbol, domain) + return valid_solns - invalid_solns + + +class _SolveTrig1Error(Exception): + """Raised when _solve_trig1 heuristics do not apply""" + +def _solve_trig(f, symbol, domain): + """Function to call other helpers to solve trigonometric equations """ + # If f is composed of a single trig function (potentially appearing multiple + # times) we should solve by either inverting directly or inverting after a + # suitable change of variable. + # + # _solve_trig is currently only called by _solveset for trig/hyperbolic + # functions of an argument linear in x. Inverting a symbolic argument should + # include a guard against division by zero in order to have a result that is + # consistent with similar processing done by _solve_trig1. + # (Ideally _invert should add these conditions by itself.) + trig_expr, count = None, 0 + for expr in preorder_traversal(f): + if isinstance(expr, (TrigonometricFunction, + HyperbolicFunction)) and expr.has(symbol): + if not trig_expr: + trig_expr, count = expr, 1 + elif expr == trig_expr: + count += 1 + else: + trig_expr, count = False, 0 + break + if count == 1: + # direct inversion + x, sol = _invert(f, 0, symbol, domain) + if x == symbol: + cond = True + if trig_expr.free_symbols - {symbol}: + a, h = trig_expr.args[0].as_independent(symbol, as_Add=True) + m, h = h.as_independent(symbol, as_Add=False) + num, den = m.as_numer_denom() + cond = Ne(num, 0) & Ne(den, 0) + return ConditionSet(symbol, cond, sol) + else: + return ConditionSet(symbol, Eq(f, 0), domain) + elif count: + # solve by change of variable + y = Dummy('y') + f_cov = f.subs(trig_expr, y) + sol_cov = solveset(f_cov, y, domain) + if isinstance(sol_cov, FiniteSet): + return Union( + *[_solve_trig(trig_expr-s, symbol, domain) for s in sol_cov]) + + sol = None + try: + # multiple trig/hyp functions; solve by rewriting to exp + sol = _solve_trig1(f, symbol, domain) + except _SolveTrig1Error: + try: + # multiple trig/hyp functions; solve by rewriting to tan(x/2) + sol = _solve_trig2(f, symbol, domain) + except ValueError: + raise NotImplementedError(filldedent(''' + Solution to this kind of trigonometric equations + is yet to be implemented''')) + return sol + + +def _solve_trig1(f, symbol, domain): + """Primary solver for trigonometric and hyperbolic equations + + Returns either the solution set as a ConditionSet (auto-evaluated to a + union of ImageSets if no variables besides 'symbol' are involved) or + raises _SolveTrig1Error if f == 0 cannot be solved. + + Notes + ===== + Algorithm: + 1. Do a change of variable x -> mu*x in arguments to trigonometric and + hyperbolic functions, in order to reduce them to small integers. (This + step is crucial to keep the degrees of the polynomials of step 4 low.) + 2. Rewrite trigonometric/hyperbolic functions as exponentials. + 3. Proceed to a 2nd change of variable, replacing exp(I*x) or exp(x) by y. + 4. Solve the resulting rational equation. + 5. Use invert_complex or invert_real to return to the original variable. + 6. If the coefficients of 'symbol' were symbolic in nature, add the + necessary consistency conditions in a ConditionSet. + + """ + # Prepare change of variable + x = Dummy('x') + if _is_function_class_equation(HyperbolicFunction, f, symbol): + cov = exp(x) + inverter = invert_real if domain.is_subset(S.Reals) else invert_complex + else: + cov = exp(I*x) + inverter = invert_complex + + f = trigsimp(f) + f_original = f + trig_functions = f.atoms(TrigonometricFunction, HyperbolicFunction) + trig_arguments = [e.args[0] for e in trig_functions] + # trigsimp may have reduced the equation to an expression + # that is independent of 'symbol' (e.g. cos**2+sin**2) + if not any(a.has(symbol) for a in trig_arguments): + return solveset(f_original, symbol, domain) + + denominators = [] + numerators = [] + for ar in trig_arguments: + try: + poly_ar = Poly(ar, symbol) + except PolynomialError: + raise _SolveTrig1Error("trig argument is not a polynomial") + if poly_ar.degree() > 1: # degree >1 still bad + raise _SolveTrig1Error("degree of variable must not exceed one") + if poly_ar.degree() == 0: # degree 0, don't care + continue + c = poly_ar.all_coeffs()[0] # got the coefficient of 'symbol' + numerators.append(fraction(c)[0]) + denominators.append(fraction(c)[1]) + + mu = lcm(denominators)/gcd(numerators) + f = f.subs(symbol, mu*x) + f = f.rewrite(exp) + f = together(f) + g, h = fraction(f) + y = Dummy('y') + g, h = g.expand(), h.expand() + g, h = g.subs(cov, y), h.subs(cov, y) + if g.has(x) or h.has(x): + raise _SolveTrig1Error("change of variable not possible") + + solns = solveset_complex(g, y) - solveset_complex(h, y) + if isinstance(solns, ConditionSet): + raise _SolveTrig1Error("polynomial has ConditionSet solution") + + if isinstance(solns, FiniteSet): + if any(isinstance(s, RootOf) for s in solns): + raise _SolveTrig1Error("polynomial results in RootOf object") + # revert the change of variable + cov = cov.subs(x, symbol/mu) + result = Union(*[inverter(cov, s, symbol)[1] for s in solns]) + # In case of symbolic coefficients, the solution set is only valid + # if numerator and denominator of mu are non-zero. + if mu.has(Symbol): + syms = (mu).atoms(Symbol) + munum, muden = fraction(mu) + condnum = munum.as_independent(*syms, as_Add=False)[1] + condden = muden.as_independent(*syms, as_Add=False)[1] + cond = And(Ne(condnum, 0), Ne(condden, 0)) + else: + cond = True + # Actual conditions are returned as part of the ConditionSet. Adding an + # intersection with C would only complicate some solution sets due to + # current limitations of intersection code. (e.g. #19154) + if domain is S.Complexes: + # This is a slight abuse of ConditionSet. Ideally this should + # be some kind of "PiecewiseSet". (See #19507 discussion) + return ConditionSet(symbol, cond, result) + else: + return ConditionSet(symbol, cond, Intersection(result, domain)) + elif solns is S.EmptySet: + return S.EmptySet + else: + raise _SolveTrig1Error("polynomial solutions must form FiniteSet") + + +def _solve_trig2(f, symbol, domain): + """Secondary helper to solve trigonometric equations, + called when first helper fails """ + f = trigsimp(f) + f_original = f + trig_functions = f.atoms(sin, cos, tan, sec, cot, csc) + trig_arguments = [e.args[0] for e in trig_functions] + denominators = [] + numerators = [] + + # todo: This solver can be extended to hyperbolics if the + # analogous change of variable to tanh (instead of tan) + # is used. + if not trig_functions: + return ConditionSet(symbol, Eq(f_original, 0), domain) + + # todo: The pre-processing below (extraction of numerators, denominators, + # gcd, lcm, mu, etc.) should be updated to the enhanced version in + # _solve_trig1. (See #19507) + for ar in trig_arguments: + try: + poly_ar = Poly(ar, symbol) + except PolynomialError: + raise ValueError("give up, we cannot solve if this is not a polynomial in x") + if poly_ar.degree() > 1: # degree >1 still bad + raise ValueError("degree of variable inside polynomial should not exceed one") + if poly_ar.degree() == 0: # degree 0, don't care + continue + c = poly_ar.all_coeffs()[0] # got the coefficient of 'symbol' + try: + numerators.append(Rational(c).p) + denominators.append(Rational(c).q) + except TypeError: + return ConditionSet(symbol, Eq(f_original, 0), domain) + + x = Dummy('x') + + mu = Rational(2)*number_lcm(*denominators)/number_gcd(*numerators) + f = f.subs(symbol, mu*x) + f = f.rewrite(tan) + f = expand_trig(f) + f = together(f) + + g, h = fraction(f) + y = Dummy('y') + g, h = g.expand(), h.expand() + g, h = g.subs(tan(x), y), h.subs(tan(x), y) + + if g.has(x) or h.has(x): + return ConditionSet(symbol, Eq(f_original, 0), domain) + solns = solveset(g, y, S.Reals) - solveset(h, y, S.Reals) + + if isinstance(solns, FiniteSet): + result = Union(*[invert_real(tan(symbol/mu), s, symbol)[1] + for s in solns]) + dsol = invert_real(tan(symbol/mu), oo, symbol)[1] + if degree(h) > degree(g): # If degree(denom)>degree(num) then there + result = Union(result, dsol) # would be another sol at Lim(denom-->oo) + return Intersection(result, domain) + elif solns is S.EmptySet: + return S.EmptySet + else: + return ConditionSet(symbol, Eq(f_original, 0), S.Reals) + + +def _solve_as_poly(f, symbol, domain=S.Complexes): + """ + Solve the equation using polynomial techniques if it already is a + polynomial equation or, with a change of variables, can be made so. + """ + result = None + if f.is_polynomial(symbol): + solns = roots(f, symbol, cubics=True, quartics=True, + quintics=True, domain='EX') + num_roots = sum(solns.values()) + if degree(f, symbol) <= num_roots: + result = FiniteSet(*solns.keys()) + else: + poly = Poly(f, symbol) + solns = poly.all_roots() + if poly.degree() <= len(solns): + result = FiniteSet(*solns) + else: + result = ConditionSet(symbol, Eq(f, 0), domain) + else: + poly = Poly(f) + if poly is None: + result = ConditionSet(symbol, Eq(f, 0), domain) + gens = [g for g in poly.gens if g.has(symbol)] + + if len(gens) == 1: + poly = Poly(poly, gens[0]) + gen = poly.gen + deg = poly.degree() + poly = Poly(poly.as_expr(), poly.gen, composite=True) + poly_solns = FiniteSet(*roots(poly, cubics=True, quartics=True, + quintics=True).keys()) + + if len(poly_solns) < deg: + result = ConditionSet(symbol, Eq(f, 0), domain) + + if gen != symbol: + y = Dummy('y') + inverter = invert_real if domain.is_subset(S.Reals) else invert_complex + lhs, rhs_s = inverter(gen, y, symbol) + if lhs == symbol: + result = Union(*[rhs_s.subs(y, s) for s in poly_solns]) + if isinstance(result, FiniteSet) and isinstance(gen, Pow + ) and gen.base.is_Rational: + result = FiniteSet(*[expand_log(i) for i in result]) + else: + result = ConditionSet(symbol, Eq(f, 0), domain) + else: + result = ConditionSet(symbol, Eq(f, 0), domain) + + if result is not None: + if isinstance(result, FiniteSet): + # this is to simplify solutions like -sqrt(-I) to sqrt(2)/2 + # - sqrt(2)*I/2. We are not expanding for solution with symbols + # or undefined functions because that makes the solution more complicated. + # For example, expand_complex(a) returns re(a) + I*im(a) + if all(s.atoms(Symbol, AppliedUndef) == set() and not isinstance(s, RootOf) + for s in result): + s = Dummy('s') + result = imageset(Lambda(s, expand_complex(s)), result) + if isinstance(result, FiniteSet) and domain != S.Complexes: + # Avoid adding gratuitous intersections with S.Complexes. Actual + # conditions should be handled elsewhere. + result = result.intersection(domain) + return result + else: + return ConditionSet(symbol, Eq(f, 0), domain) + + +def _solve_radical(f, unradf, symbol, solveset_solver): + """ Helper function to solve equations with radicals """ + res = unradf + eq, cov = res if res else (f, []) + if not cov: + result = solveset_solver(eq, symbol) - \ + Union(*[solveset_solver(g, symbol) for g in denoms(f, symbol)]) + else: + y, yeq = cov + if not solveset_solver(y - I, y): + yreal = Dummy('yreal', real=True) + yeq = yeq.xreplace({y: yreal}) + eq = eq.xreplace({y: yreal}) + y = yreal + g_y_s = solveset_solver(yeq, symbol) + f_y_sols = solveset_solver(eq, y) + result = Union(*[imageset(Lambda(y, g_y), f_y_sols) + for g_y in g_y_s]) + + def check_finiteset(solutions): + f_set = [] # solutions for FiniteSet + c_set = [] # solutions for ConditionSet + for s in solutions: + if checksol(f, symbol, s): + f_set.append(s) + else: + c_set.append(s) + return FiniteSet(*f_set) + ConditionSet(symbol, Eq(f, 0), FiniteSet(*c_set)) + + def check_set(solutions): + if solutions is S.EmptySet: + return solutions + elif isinstance(solutions, ConditionSet): + # XXX: Maybe the base set should be checked? + return solutions + elif isinstance(solutions, FiniteSet): + return check_finiteset(solutions) + elif isinstance(solutions, Complement): + A, B = solutions.args + return Complement(check_set(A), B) + elif isinstance(solutions, Union): + return Union(*[check_set(s) for s in solutions.args]) + else: + # XXX: There should be more cases checked here. The cases above + # are all those that come up in the test suite for now. + return solutions + + solution_set = check_set(result) + + return solution_set + + +def _solve_abs(f, symbol, domain): + """ Helper function to solve equation involving absolute value function """ + if not domain.is_subset(S.Reals): + raise ValueError(filldedent(''' + Absolute values cannot be inverted in the + complex domain.''')) + p, q, r = Wild('p'), Wild('q'), Wild('r') + pattern_match = f.match(p*Abs(q) + r) or {} + f_p, f_q, f_r = [pattern_match.get(i, S.Zero) for i in (p, q, r)] + + if not (f_p.is_zero or f_q.is_zero): + domain = continuous_domain(f_q, symbol, domain) + from .inequalities import solve_univariate_inequality + q_pos_cond = solve_univariate_inequality(f_q >= 0, symbol, + relational=False, domain=domain, continuous=True) + q_neg_cond = q_pos_cond.complement(domain) + + sols_q_pos = solveset_real(f_p*f_q + f_r, + symbol).intersect(q_pos_cond) + sols_q_neg = solveset_real(f_p*(-f_q) + f_r, + symbol).intersect(q_neg_cond) + return Union(sols_q_pos, sols_q_neg) + else: + return ConditionSet(symbol, Eq(f, 0), domain) + + +def solve_decomposition(f, symbol, domain): + """ + Function to solve equations via the principle of "Decomposition + and Rewriting". + + Examples + ======== + >>> from sympy import exp, sin, Symbol, pprint, S + >>> from sympy.solvers.solveset import solve_decomposition as sd + >>> x = Symbol('x') + >>> f1 = exp(2*x) - 3*exp(x) + 2 + >>> sd(f1, x, S.Reals) + {0, log(2)} + >>> f2 = sin(x)**2 + 2*sin(x) + 1 + >>> pprint(sd(f2, x, S.Reals), use_unicode=False) + 3*pi + {2*n*pi + ---- | n in Integers} + 2 + >>> f3 = sin(x + 2) + >>> pprint(sd(f3, x, S.Reals), use_unicode=False) + {2*n*pi - 2 | n in Integers} U {2*n*pi - 2 + pi | n in Integers} + + """ + from sympy.solvers.decompogen import decompogen + # decompose the given function + g_s = decompogen(f, symbol) + # `y_s` represents the set of values for which the function `g` is to be + # solved. + # `solutions` represent the solutions of the equations `g = y_s` or + # `g = 0` depending on the type of `y_s`. + # As we are interested in solving the equation: f = 0 + y_s = FiniteSet(0) + for g in g_s: + frange = function_range(g, symbol, domain) + y_s = Intersection(frange, y_s) + result = S.EmptySet + if isinstance(y_s, FiniteSet): + for y in y_s: + solutions = solveset(Eq(g, y), symbol, domain) + if not isinstance(solutions, ConditionSet): + result += solutions + + else: + if isinstance(y_s, ImageSet): + iter_iset = (y_s,) + + elif isinstance(y_s, Union): + iter_iset = y_s.args + + elif y_s is S.EmptySet: + # y_s is not in the range of g in g_s, so no solution exists + #in the given domain + return S.EmptySet + + for iset in iter_iset: + new_solutions = solveset(Eq(iset.lamda.expr, g), symbol, domain) + dummy_var = tuple(iset.lamda.expr.free_symbols)[0] + (base_set,) = iset.base_sets + if isinstance(new_solutions, FiniteSet): + new_exprs = new_solutions + + elif isinstance(new_solutions, Intersection): + if isinstance(new_solutions.args[1], FiniteSet): + new_exprs = new_solutions.args[1] + + for new_expr in new_exprs: + result += ImageSet(Lambda(dummy_var, new_expr), base_set) + + if result is S.EmptySet: + return ConditionSet(symbol, Eq(f, 0), domain) + + y_s = result + + return y_s + + +def _solveset(f, symbol, domain, _check=False): + """Helper for solveset to return a result from an expression + that has already been sympify'ed and is known to contain the + given symbol.""" + # _check controls whether the answer is checked or not + from sympy.simplify.simplify import signsimp + + if isinstance(f, BooleanTrue): + return domain + + orig_f = f + if f.is_Mul: + coeff, f = f.as_independent(symbol, as_Add=False) + if coeff in {S.ComplexInfinity, S.NegativeInfinity, S.Infinity}: + f = together(orig_f) + elif f.is_Add: + a, h = f.as_independent(symbol) + m, h = h.as_independent(symbol, as_Add=False) + if m not in {S.ComplexInfinity, S.Zero, S.Infinity, + S.NegativeInfinity}: + f = a/m + h # XXX condition `m != 0` should be added to soln + + # assign the solvers to use + solver = lambda f, x, domain=domain: _solveset(f, x, domain) + inverter = lambda f, rhs, symbol: _invert(f, rhs, symbol, domain) + + result = S.EmptySet + + if f.expand().is_zero: + return domain + elif not f.has(symbol): + return S.EmptySet + elif f.is_Mul and all(_is_finite_with_finite_vars(m, domain) + for m in f.args): + # if f(x) and g(x) are both finite we can say that the solution of + # f(x)*g(x) == 0 is same as Union(f(x) == 0, g(x) == 0) is not true in + # general. g(x) can grow to infinitely large for the values where + # f(x) == 0. To be sure that we are not silently allowing any + # wrong solutions we are using this technique only if both f and g are + # finite for a finite input. + result = Union(*[solver(m, symbol) for m in f.args]) + elif (_is_function_class_equation(TrigonometricFunction, f, symbol) or \ + _is_function_class_equation(HyperbolicFunction, f, symbol)): + result = _solve_trig(f, symbol, domain) + elif isinstance(f, arg): + a = f.args[0] + result = Intersection(_solveset(re(a) > 0, symbol, domain), + _solveset(im(a), symbol, domain)) + elif f.is_Piecewise: + expr_set_pairs = f.as_expr_set_pairs(domain) + for (expr, in_set) in expr_set_pairs: + if in_set.is_Relational: + in_set = in_set.as_set() + solns = solver(expr, symbol, in_set) + result += solns + elif isinstance(f, Eq): + result = solver(Add(f.lhs, -f.rhs, evaluate=False), symbol, domain) + + elif f.is_Relational: + from .inequalities import solve_univariate_inequality + try: + result = solve_univariate_inequality( + f, symbol, domain=domain, relational=False) + except NotImplementedError: + result = ConditionSet(symbol, f, domain) + return result + elif _is_modular(f, symbol): + result = _solve_modular(f, symbol, domain) + else: + lhs, rhs_s = inverter(f, 0, symbol) + if lhs == symbol: + # do some very minimal simplification since + # repeated inversion may have left the result + # in a state that other solvers (e.g. poly) + # would have simplified; this is done here + # rather than in the inverter since here it + # is only done once whereas there it would + # be repeated for each step of the inversion + if isinstance(rhs_s, FiniteSet): + rhs_s = FiniteSet(*[Mul(* + signsimp(i).as_content_primitive()) + for i in rhs_s]) + result = rhs_s + + elif isinstance(rhs_s, FiniteSet): + for equation in [lhs - rhs for rhs in rhs_s]: + if equation == f: + u = unrad(f, symbol) + if u: + result += _solve_radical(equation, u, + symbol, + solver) + elif equation.has(Abs): + result += _solve_abs(f, symbol, domain) + else: + result_rational = _solve_as_rational(equation, symbol, domain) + if not isinstance(result_rational, ConditionSet): + result += result_rational + else: + # may be a transcendental type equation + t_result = _transolve(equation, symbol, domain) + if isinstance(t_result, ConditionSet): + # might need factoring; this is expensive so we + # have delayed until now. To avoid recursion + # errors look for a non-trivial factoring into + # a product of symbol dependent terms; I think + # that something that factors as a Pow would + # have already been recognized by now. + factored = equation.factor() + if factored.is_Mul and equation != factored: + _, dep = factored.as_independent(symbol) + if not dep.is_Add: + # non-trivial factoring of equation + # but use form with constants + # in case they need special handling + t_results = [] + for fac in Mul.make_args(factored): + if fac.has(symbol): + t_results.append(solver(fac, symbol)) + t_result = Union(*t_results) + result += t_result + else: + result += solver(equation, symbol) + + elif rhs_s is not S.EmptySet: + result = ConditionSet(symbol, Eq(f, 0), domain) + + if isinstance(result, ConditionSet): + if isinstance(f, Expr): + num, den = f.as_numer_denom() + if den.has(symbol): + _result = _solveset(num, symbol, domain) + if not isinstance(_result, ConditionSet): + singularities = _solveset(den, symbol, domain) + result = _result - singularities + + if _check: + if isinstance(result, ConditionSet): + # it wasn't solved or has enumerated all conditions + # -- leave it alone + return result + + # whittle away all but the symbol-containing core + # to use this for testing + if isinstance(orig_f, Expr): + fx = orig_f.as_independent(symbol, as_Add=True)[1] + fx = fx.as_independent(symbol, as_Add=False)[1] + else: + fx = orig_f + + if isinstance(result, FiniteSet): + # check the result for invalid solutions + result = FiniteSet(*[s for s in result + if isinstance(s, RootOf) + or domain_check(fx, symbol, s)]) + + return result + + +def _is_modular(f, symbol): + """ + Helper function to check below mentioned types of modular equations. + ``A - Mod(B, C) = 0`` + + A -> This can or cannot be a function of symbol. + B -> This is surely a function of symbol. + C -> It is an integer. + + Parameters + ========== + + f : Expr + The equation to be checked. + + symbol : Symbol + The concerned variable for which the equation is to be checked. + + Examples + ======== + + >>> from sympy import symbols, exp, Mod + >>> from sympy.solvers.solveset import _is_modular as check + >>> x, y = symbols('x y') + >>> check(Mod(x, 3) - 1, x) + True + >>> check(Mod(x, 3) - 1, y) + False + >>> check(Mod(x, 3)**2 - 5, x) + False + >>> check(Mod(x, 3)**2 - y, x) + False + >>> check(exp(Mod(x, 3)) - 1, x) + False + >>> check(Mod(3, y) - 1, y) + False + """ + + if not f.has(Mod): + return False + + # extract modterms from f. + modterms = list(f.atoms(Mod)) + + return (len(modterms) == 1 and # only one Mod should be present + modterms[0].args[0].has(symbol) and # B-> function of symbol + modterms[0].args[1].is_integer and # C-> to be an integer. + any(isinstance(term, Mod) + for term in list(_term_factors(f))) # free from other funcs + ) + + +def _invert_modular(modterm, rhs, n, symbol): + """ + Helper function to invert modular equation. + ``Mod(a, m) - rhs = 0`` + + Generally it is inverted as (a, ImageSet(Lambda(n, m*n + rhs), S.Integers)). + More simplified form will be returned if possible. + + If it is not invertible then (modterm, rhs) is returned. + + The following cases arise while inverting equation ``Mod(a, m) - rhs = 0``: + + 1. If a is symbol then m*n + rhs is the required solution. + + 2. If a is an instance of ``Add`` then we try to find two symbol independent + parts of a and the symbol independent part gets transferred to the other + side and again the ``_invert_modular`` is called on the symbol + dependent part. + + 3. If a is an instance of ``Mul`` then same as we done in ``Add`` we separate + out the symbol dependent and symbol independent parts and transfer the + symbol independent part to the rhs with the help of invert and again the + ``_invert_modular`` is called on the symbol dependent part. + + 4. If a is an instance of ``Pow`` then two cases arise as following: + + - If a is of type (symbol_indep)**(symbol_dep) then the remainder is + evaluated with the help of discrete_log function and then the least + period is being found out with the help of totient function. + period*n + remainder is the required solution in this case. + For reference: (https://en.wikipedia.org/wiki/Euler's_theorem) + + - If a is of type (symbol_dep)**(symbol_indep) then we try to find all + primitive solutions list with the help of nthroot_mod function. + m*n + rem is the general solution where rem belongs to solutions list + from nthroot_mod function. + + Parameters + ========== + + modterm, rhs : Expr + The modular equation to be inverted, ``modterm - rhs = 0`` + + symbol : Symbol + The variable in the equation to be inverted. + + n : Dummy + Dummy variable for output g_n. + + Returns + ======= + + A tuple (f_x, g_n) is being returned where f_x is modular independent function + of symbol and g_n being set of values f_x can have. + + Examples + ======== + + >>> from sympy import symbols, exp, Mod, Dummy, S + >>> from sympy.solvers.solveset import _invert_modular as invert_modular + >>> x, y = symbols('x y') + >>> n = Dummy('n') + >>> invert_modular(Mod(exp(x), 7), S(5), n, x) + (Mod(exp(x), 7), 5) + >>> invert_modular(Mod(x, 7), S(5), n, x) + (x, ImageSet(Lambda(_n, 7*_n + 5), Integers)) + >>> invert_modular(Mod(3*x + 8, 7), S(5), n, x) + (x, ImageSet(Lambda(_n, 7*_n + 6), Integers)) + >>> invert_modular(Mod(x**4, 7), S(5), n, x) + (x, EmptySet) + >>> invert_modular(Mod(2**(x**2 + x + 1), 7), S(2), n, x) + (x**2 + x + 1, ImageSet(Lambda(_n, 3*_n + 1), Naturals0)) + + """ + a, m = modterm.args + + if rhs.is_integer is False: + return symbol, S.EmptySet + + if rhs.is_real is False or any(term.is_real is False + for term in list(_term_factors(a))): + # Check for complex arguments + return modterm, rhs + + if abs(rhs) >= abs(m): + # if rhs has value greater than value of m. + return symbol, S.EmptySet + + if a == symbol: + return symbol, ImageSet(Lambda(n, m*n + rhs), S.Integers) + + if a.is_Add: + # g + h = a + g, h = a.as_independent(symbol) + if g is not S.Zero: + x_indep_term = rhs - Mod(g, m) + return _invert_modular(Mod(h, m), Mod(x_indep_term, m), n, symbol) + + if a.is_Mul: + # g*h = a + g, h = a.as_independent(symbol) + if g is not S.One: + x_indep_term = rhs*invert(g, m) + return _invert_modular(Mod(h, m), Mod(x_indep_term, m), n, symbol) + + if a.is_Pow: + # base**expo = a + base, expo = a.args + if expo.has(symbol) and not base.has(symbol): + # remainder -> solution independent of n of equation. + # m, rhs are made coprime by dividing number_gcd(m, rhs) + if not m.is_Integer and rhs.is_Integer and a.base.is_Integer: + return modterm, rhs + + mdiv = m.p // number_gcd(m.p, rhs.p) + try: + remainder = discrete_log(mdiv, rhs.p, a.base.p) + except ValueError: # log does not exist + return modterm, rhs + # period -> coefficient of n in the solution and also referred as + # the least period of expo in which it is repeats itself. + # (a**(totient(m)) - 1) divides m. Here is link of theorem: + # (https://en.wikipedia.org/wiki/Euler's_theorem) + period = totient(m) + for p in divisors(period): + # there might a lesser period exist than totient(m). + if pow(a.base, p, m / number_gcd(m.p, a.base.p)) == 1: + period = p + break + # recursion is not applied here since _invert_modular is currently + # not smart enough to handle infinite rhs as here expo has infinite + # rhs = ImageSet(Lambda(n, period*n + remainder), S.Naturals0). + return expo, ImageSet(Lambda(n, period*n + remainder), S.Naturals0) + elif base.has(symbol) and not expo.has(symbol): + try: + remainder_list = nthroot_mod(rhs, expo, m, all_roots=True) + if remainder_list == []: + return symbol, S.EmptySet + except (ValueError, NotImplementedError): + return modterm, rhs + g_n = S.EmptySet + for rem in remainder_list: + g_n += ImageSet(Lambda(n, m*n + rem), S.Integers) + return base, g_n + + return modterm, rhs + + +def _solve_modular(f, symbol, domain): + r""" + Helper function for solving modular equations of type ``A - Mod(B, C) = 0``, + where A can or cannot be a function of symbol, B is surely a function of + symbol and C is an integer. + + Currently ``_solve_modular`` is only able to solve cases + where A is not a function of symbol. + + Parameters + ========== + + f : Expr + The modular equation to be solved, ``f = 0`` + + symbol : Symbol + The variable in the equation to be solved. + + domain : Set + A set over which the equation is solved. It has to be a subset of + Integers. + + Returns + ======= + + A set of integer solutions satisfying the given modular equation. + A ``ConditionSet`` if the equation is unsolvable. + + Examples + ======== + + >>> from sympy.solvers.solveset import _solve_modular as solve_modulo + >>> from sympy import S, Symbol, sin, Intersection, Interval, Mod + >>> x = Symbol('x') + >>> solve_modulo(Mod(5*x - 8, 7) - 3, x, S.Integers) + ImageSet(Lambda(_n, 7*_n + 5), Integers) + >>> solve_modulo(Mod(5*x - 8, 7) - 3, x, S.Reals) # domain should be subset of integers. + ConditionSet(x, Eq(Mod(5*x + 6, 7) - 3, 0), Reals) + >>> solve_modulo(-7 + Mod(x, 5), x, S.Integers) + EmptySet + >>> solve_modulo(Mod(12**x, 21) - 18, x, S.Integers) + ImageSet(Lambda(_n, 6*_n + 2), Naturals0) + >>> solve_modulo(Mod(sin(x), 7) - 3, x, S.Integers) # not solvable + ConditionSet(x, Eq(Mod(sin(x), 7) - 3, 0), Integers) + >>> solve_modulo(3 - Mod(x, 5), x, Intersection(S.Integers, Interval(0, 100))) + Intersection(ImageSet(Lambda(_n, 5*_n + 3), Integers), Range(0, 101, 1)) + """ + # extract modterm and g_y from f + unsolved_result = ConditionSet(symbol, Eq(f, 0), domain) + modterm = list(f.atoms(Mod))[0] + rhs = -S.One*(f.subs(modterm, S.Zero)) + if f.as_coefficients_dict()[modterm].is_negative: + # checks if coefficient of modterm is negative in main equation. + rhs *= -S.One + + if not domain.is_subset(S.Integers): + return unsolved_result + + if rhs.has(symbol): + # TODO Case: A-> function of symbol, can be extended here + # in future. + return unsolved_result + + n = Dummy('n', integer=True) + f_x, g_n = _invert_modular(modterm, rhs, n, symbol) + + if f_x == modterm and g_n == rhs: + return unsolved_result + + if f_x == symbol: + if domain is not S.Integers: + return domain.intersect(g_n) + return g_n + + if isinstance(g_n, ImageSet): + lamda_expr = g_n.lamda.expr + lamda_vars = g_n.lamda.variables + base_sets = g_n.base_sets + sol_set = _solveset(f_x - lamda_expr, symbol, S.Integers) + if isinstance(sol_set, FiniteSet): + tmp_sol = S.EmptySet + for sol in sol_set: + tmp_sol += ImageSet(Lambda(lamda_vars, sol), *base_sets) + sol_set = tmp_sol + else: + sol_set = ImageSet(Lambda(lamda_vars, sol_set), *base_sets) + return domain.intersect(sol_set) + + return unsolved_result + + +def _term_factors(f): + """ + Iterator to get the factors of all terms present + in the given equation. + + Parameters + ========== + f : Expr + Equation that needs to be addressed + + Returns + ======= + Factors of all terms present in the equation. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.solvers.solveset import _term_factors + >>> x = symbols('x') + >>> list(_term_factors(-2 - x**2 + x*(x + 1))) + [-2, -1, x**2, x, x + 1] + """ + for add_arg in Add.make_args(f): + yield from Mul.make_args(add_arg) + + +def _solve_exponential(lhs, rhs, symbol, domain): + r""" + Helper function for solving (supported) exponential equations. + + Exponential equations are the sum of (currently) at most + two terms with one or both of them having a power with a + symbol-dependent exponent. + + For example + + .. math:: 5^{2x + 3} - 5^{3x - 1} + + .. math:: 4^{5 - 9x} - e^{2 - x} + + Parameters + ========== + + lhs, rhs : Expr + The exponential equation to be solved, `lhs = rhs` + + symbol : Symbol + The variable in which the equation is solved + + domain : Set + A set over which the equation is solved. + + Returns + ======= + + A set of solutions satisfying the given equation. + A ``ConditionSet`` if the equation is unsolvable or + if the assumptions are not properly defined, in that case + a different style of ``ConditionSet`` is returned having the + solution(s) of the equation with the desired assumptions. + + Examples + ======== + + >>> from sympy.solvers.solveset import _solve_exponential as solve_expo + >>> from sympy import symbols, S + >>> x = symbols('x', real=True) + >>> a, b = symbols('a b') + >>> solve_expo(2**x + 3**x - 5**x, 0, x, S.Reals) # not solvable + ConditionSet(x, Eq(2**x + 3**x - 5**x, 0), Reals) + >>> solve_expo(a**x - b**x, 0, x, S.Reals) # solvable but incorrect assumptions + ConditionSet(x, (a > 0) & (b > 0), {0}) + >>> solve_expo(3**(2*x) - 2**(x + 3), 0, x, S.Reals) + {-3*log(2)/(-2*log(3) + log(2))} + >>> solve_expo(2**x - 4**x, 0, x, S.Reals) + {0} + + * Proof of correctness of the method + + The logarithm function is the inverse of the exponential function. + The defining relation between exponentiation and logarithm is: + + .. math:: {\log_b x} = y \enspace if \enspace b^y = x + + Therefore if we are given an equation with exponent terms, we can + convert every term to its corresponding logarithmic form. This is + achieved by taking logarithms and expanding the equation using + logarithmic identities so that it can easily be handled by ``solveset``. + + For example: + + .. math:: 3^{2x} = 2^{x + 3} + + Taking log both sides will reduce the equation to + + .. math:: (2x)\log(3) = (x + 3)\log(2) + + This form can be easily handed by ``solveset``. + """ + unsolved_result = ConditionSet(symbol, Eq(lhs - rhs, 0), domain) + newlhs = powdenest(lhs) + if lhs != newlhs: + # it may also be advantageous to factor the new expr + neweq = factor(newlhs - rhs) + if neweq != (lhs - rhs): + return _solveset(neweq, symbol, domain) # try again with _solveset + + if not (isinstance(lhs, Add) and len(lhs.args) == 2): + # solving for the sum of more than two powers is possible + # but not yet implemented + return unsolved_result + + if rhs != 0: + return unsolved_result + + a, b = list(ordered(lhs.args)) + a_term = a.as_independent(symbol)[1] + b_term = b.as_independent(symbol)[1] + + a_base, a_exp = a_term.as_base_exp() + b_base, b_exp = b_term.as_base_exp() + + if domain.is_subset(S.Reals): + conditions = And( + a_base > 0, + b_base > 0, + Eq(im(a_exp), 0), + Eq(im(b_exp), 0)) + else: + conditions = And( + Ne(a_base, 0), + Ne(b_base, 0)) + + L, R = (expand_log(log(i), force=True) for i in (a, -b)) + solutions = _solveset(L - R, symbol, domain) + + return ConditionSet(symbol, conditions, solutions) + + +def _is_exponential(f, symbol): + r""" + Return ``True`` if one or more terms contain ``symbol`` only in + exponents, else ``False``. + + Parameters + ========== + + f : Expr + The equation to be checked + + symbol : Symbol + The variable in which the equation is checked + + Examples + ======== + + >>> from sympy import symbols, cos, exp + >>> from sympy.solvers.solveset import _is_exponential as check + >>> x, y = symbols('x y') + >>> check(y, y) + False + >>> check(x**y - 1, y) + True + >>> check(x**y*2**y - 1, y) + True + >>> check(exp(x + 3) + 3**x, x) + True + >>> check(cos(2**x), x) + False + + * Philosophy behind the helper + + The function extracts each term of the equation and checks if it is + of exponential form w.r.t ``symbol``. + """ + rv = False + for expr_arg in _term_factors(f): + if symbol not in expr_arg.free_symbols: + continue + if (isinstance(expr_arg, Pow) and + symbol not in expr_arg.base.free_symbols or + isinstance(expr_arg, exp)): + rv = True # symbol in exponent + else: + return False # dependent on symbol in non-exponential way + return rv + + +def _solve_logarithm(lhs, rhs, symbol, domain): + r""" + Helper to solve logarithmic equations which are reducible + to a single instance of `\log`. + + Logarithmic equations are (currently) the equations that contains + `\log` terms which can be reduced to a single `\log` term or + a constant using various logarithmic identities. + + For example: + + .. math:: \log(x) + \log(x - 4) + + can be reduced to: + + .. math:: \log(x(x - 4)) + + Parameters + ========== + + lhs, rhs : Expr + The logarithmic equation to be solved, `lhs = rhs` + + symbol : Symbol + The variable in which the equation is solved + + domain : Set + A set over which the equation is solved. + + Returns + ======= + + A set of solutions satisfying the given equation. + A ``ConditionSet`` if the equation is unsolvable. + + Examples + ======== + + >>> from sympy import symbols, log, S + >>> from sympy.solvers.solveset import _solve_logarithm as solve_log + >>> x = symbols('x') + >>> f = log(x - 3) + log(x + 3) + >>> solve_log(f, 0, x, S.Reals) + {-sqrt(10), sqrt(10)} + + * Proof of correctness + + A logarithm is another way to write exponent and is defined by + + .. math:: {\log_b x} = y \enspace if \enspace b^y = x + + When one side of the equation contains a single logarithm, the + equation can be solved by rewriting the equation as an equivalent + exponential equation as defined above. But if one side contains + more than one logarithm, we need to use the properties of logarithm + to condense it into a single logarithm. + + Take for example + + .. math:: \log(2x) - 15 = 0 + + contains single logarithm, therefore we can directly rewrite it to + exponential form as + + .. math:: x = \frac{e^{15}}{2} + + But if the equation has more than one logarithm as + + .. math:: \log(x - 3) + \log(x + 3) = 0 + + we use logarithmic identities to convert it into a reduced form + + Using, + + .. math:: \log(a) + \log(b) = \log(ab) + + the equation becomes, + + .. math:: \log((x - 3)(x + 3)) + + This equation contains one logarithm and can be solved by rewriting + to exponents. + """ + new_lhs = logcombine(lhs, force=True) + new_f = new_lhs - rhs + + return _solveset(new_f, symbol, domain) + + +def _is_logarithmic(f, symbol): + r""" + Return ``True`` if the equation is in the form + `a\log(f(x)) + b\log(g(x)) + ... + c` else ``False``. + + Parameters + ========== + + f : Expr + The equation to be checked + + symbol : Symbol + The variable in which the equation is checked + + Returns + ======= + + ``True`` if the equation is logarithmic otherwise ``False``. + + Examples + ======== + + >>> from sympy import symbols, tan, log + >>> from sympy.solvers.solveset import _is_logarithmic as check + >>> x, y = symbols('x y') + >>> check(log(x + 2) - log(x + 3), x) + True + >>> check(tan(log(2*x)), x) + False + >>> check(x*log(x), x) + False + >>> check(x + log(x), x) + False + >>> check(y + log(x), x) + True + + * Philosophy behind the helper + + The function extracts each term and checks whether it is + logarithmic w.r.t ``symbol``. + """ + rv = False + for term in Add.make_args(f): + saw_log = False + for term_arg in Mul.make_args(term): + if symbol not in term_arg.free_symbols: + continue + if isinstance(term_arg, log): + if saw_log: + return False # more than one log in term + saw_log = True + else: + return False # dependent on symbol in non-log way + if saw_log: + rv = True + return rv + + +def _is_lambert(f, symbol): + r""" + If this returns ``False`` then the Lambert solver (``_solve_lambert``) will not be called. + + Explanation + =========== + + Quick check for cases that the Lambert solver might be able to handle. + + 1. Equations containing more than two operands and `symbol`s involving any of + `Pow`, `exp`, `HyperbolicFunction`,`TrigonometricFunction`, `log` terms. + + 2. In `Pow`, `exp` the exponent should have `symbol` whereas for + `HyperbolicFunction`,`TrigonometricFunction`, `log` should contain `symbol`. + + 3. For `HyperbolicFunction`,`TrigonometricFunction` the number of trigonometric functions in + equation should be less than number of symbols. (since `A*cos(x) + B*sin(x) - c` + is not the Lambert type). + + Some forms of lambert equations are: + 1. X**X = C + 2. X*(B*log(X) + D)**A = C + 3. A*log(B*X + A) + d*X = C + 4. (B*X + A)*exp(d*X + g) = C + 5. g*exp(B*X + h) - B*X = C + 6. A*D**(E*X + g) - B*X = C + 7. A*cos(X) + B*sin(X) - D*X = C + 8. A*cosh(X) + B*sinh(X) - D*X = C + + Where X is any variable, + A, B, C, D, E are any constants, + g, h are linear functions or log terms. + + Parameters + ========== + + f : Expr + The equation to be checked + + symbol : Symbol + The variable in which the equation is checked + + Returns + ======= + + If this returns ``False`` then the Lambert solver (``_solve_lambert``) will not be called. + + Examples + ======== + + >>> from sympy.solvers.solveset import _is_lambert + >>> from sympy import symbols, cosh, sinh, log + >>> x = symbols('x') + + >>> _is_lambert(3*log(x) - x*log(3), x) + True + >>> _is_lambert(log(log(x - 3)) + log(x-3), x) + True + >>> _is_lambert(cosh(x) - sinh(x), x) + False + >>> _is_lambert((x**2 - 2*x + 1).subs(x, (log(x) + 3*x)**2 - 1), x) + True + + See Also + ======== + + _solve_lambert + + """ + term_factors = list(_term_factors(f.expand())) + + # total number of symbols in equation + no_of_symbols = len([arg for arg in term_factors if arg.has(symbol)]) + # total number of trigonometric terms in equation + no_of_trig = len([arg for arg in term_factors \ + if arg.has(HyperbolicFunction, TrigonometricFunction)]) + + if f.is_Add and no_of_symbols >= 2: + # `log`, `HyperbolicFunction`, `TrigonometricFunction` should have symbols + # and no_of_trig < no_of_symbols + lambert_funcs = (log, HyperbolicFunction, TrigonometricFunction) + if any(isinstance(arg, lambert_funcs)\ + for arg in term_factors if arg.has(symbol)): + if no_of_trig < no_of_symbols: + return True + # here, `Pow`, `exp` exponent should have symbols + elif any(isinstance(arg, (Pow, exp)) \ + for arg in term_factors if (arg.as_base_exp()[1]).has(symbol)): + return True + return False + + +def _transolve(f, symbol, domain): + r""" + Function to solve transcendental equations. It is a helper to + ``solveset`` and should be used internally. ``_transolve`` + currently supports the following class of equations: + + - Exponential equations + - Logarithmic equations + + Parameters + ========== + + f : Any transcendental equation that needs to be solved. + This needs to be an expression, which is assumed + to be equal to ``0``. + + symbol : The variable for which the equation is solved. + This needs to be of class ``Symbol``. + + domain : A set over which the equation is solved. + This needs to be of class ``Set``. + + Returns + ======= + + Set + A set of values for ``symbol`` for which ``f`` is equal to + zero. An ``EmptySet`` is returned if ``f`` does not have solutions + in respective domain. A ``ConditionSet`` is returned as unsolved + object if algorithms to evaluate complete solution are not + yet implemented. + + How to use ``_transolve`` + ========================= + + ``_transolve`` should not be used as an independent function, because + it assumes that the equation (``f``) and the ``symbol`` comes from + ``solveset`` and might have undergone a few modification(s). + To use ``_transolve`` as an independent function the equation (``f``) + and the ``symbol`` should be passed as they would have been by + ``solveset``. + + Examples + ======== + + >>> from sympy.solvers.solveset import _transolve as transolve + >>> from sympy.solvers.solvers import _tsolve as tsolve + >>> from sympy import symbols, S, pprint + >>> x = symbols('x', real=True) # assumption added + >>> transolve(5**(x - 3) - 3**(2*x + 1), x, S.Reals) + {-(log(3) + 3*log(5))/(-log(5) + 2*log(3))} + + How ``_transolve`` works + ======================== + + ``_transolve`` uses two types of helper functions to solve equations + of a particular class: + + Identifying helpers: To determine whether a given equation + belongs to a certain class of equation or not. Returns either + ``True`` or ``False``. + + Solving helpers: Once an equation is identified, a corresponding + helper either solves the equation or returns a form of the equation + that ``solveset`` might better be able to handle. + + * Philosophy behind the module + + The purpose of ``_transolve`` is to take equations which are not + already polynomial in their generator(s) and to either recast them + as such through a valid transformation or to solve them outright. + A pair of helper functions for each class of supported + transcendental functions are employed for this purpose. One + identifies the transcendental form of an equation and the other + either solves it or recasts it into a tractable form that can be + solved by ``solveset``. + For example, an equation in the form `ab^{f(x)} - cd^{g(x)} = 0` + can be transformed to + `\log(a) + f(x)\log(b) - \log(c) - g(x)\log(d) = 0` + (under certain assumptions) and this can be solved with ``solveset`` + if `f(x)` and `g(x)` are in polynomial form. + + How ``_transolve`` is better than ``_tsolve`` + ============================================= + + 1) Better output + + ``_transolve`` provides expressions in a more simplified form. + + Consider a simple exponential equation + + >>> f = 3**(2*x) - 2**(x + 3) + >>> pprint(transolve(f, x, S.Reals), use_unicode=False) + -3*log(2) + {------------------} + -2*log(3) + log(2) + >>> pprint(tsolve(f, x), use_unicode=False) + / 3 \ + | --------| + | log(2/9)| + [-log\2 /] + + 2) Extensible + + The API of ``_transolve`` is designed such that it is easily + extensible, i.e. the code that solves a given class of + equations is encapsulated in a helper and not mixed in with + the code of ``_transolve`` itself. + + 3) Modular + + ``_transolve`` is designed to be modular i.e, for every class of + equation a separate helper for identification and solving is + implemented. This makes it easy to change or modify any of the + method implemented directly in the helpers without interfering + with the actual structure of the API. + + 4) Faster Computation + + Solving equation via ``_transolve`` is much faster as compared to + ``_tsolve``. In ``solve``, attempts are made computing every possibility + to get the solutions. This series of attempts makes solving a bit + slow. In ``_transolve``, computation begins only after a particular + type of equation is identified. + + How to add new class of equations + ================================= + + Adding a new class of equation solver is a three-step procedure: + + - Identify the type of the equations + + Determine the type of the class of equations to which they belong: + it could be of ``Add``, ``Pow``, etc. types. Separate internal functions + are used for each type. Write identification and solving helpers + and use them from within the routine for the given type of equation + (after adding it, if necessary). Something like: + + .. code-block:: python + + def add_type(lhs, rhs, x): + .... + if _is_exponential(lhs, x): + new_eq = _solve_exponential(lhs, rhs, x) + .... + rhs, lhs = eq.as_independent(x) + if lhs.is_Add: + result = add_type(lhs, rhs, x) + + - Define the identification helper. + + - Define the solving helper. + + Apart from this, a few other things needs to be taken care while + adding an equation solver: + + - Naming conventions: + Name of the identification helper should be as + ``_is_class`` where class will be the name or abbreviation + of the class of equation. The solving helper will be named as + ``_solve_class``. + For example: for exponential equations it becomes + ``_is_exponential`` and ``_solve_expo``. + - The identifying helpers should take two input parameters, + the equation to be checked and the variable for which a solution + is being sought, while solving helpers would require an additional + domain parameter. + - Be sure to consider corner cases. + - Add tests for each helper. + - Add a docstring to your helper that describes the method + implemented. + The documentation of the helpers should identify: + + - the purpose of the helper, + - the method used to identify and solve the equation, + - a proof of correctness + - the return values of the helpers + """ + + def add_type(lhs, rhs, symbol, domain): + """ + Helper for ``_transolve`` to handle equations of + ``Add`` type, i.e. equations taking the form as + ``a*f(x) + b*g(x) + .... = c``. + For example: 4**x + 8**x = 0 + """ + result = ConditionSet(symbol, Eq(lhs - rhs, 0), domain) + + # check if it is exponential type equation + if _is_exponential(lhs, symbol): + result = _solve_exponential(lhs, rhs, symbol, domain) + # check if it is logarithmic type equation + elif _is_logarithmic(lhs, symbol): + result = _solve_logarithm(lhs, rhs, symbol, domain) + + return result + + result = ConditionSet(symbol, Eq(f, 0), domain) + + # invert_complex handles the call to the desired inverter based + # on the domain specified. + lhs, rhs_s = invert_complex(f, 0, symbol, domain) + + if isinstance(rhs_s, FiniteSet): + assert (len(rhs_s.args)) == 1 + rhs = rhs_s.args[0] + + if lhs.is_Add: + result = add_type(lhs, rhs, symbol, domain) + else: + result = rhs_s + + return result + + +def solveset(f, symbol=None, domain=S.Complexes): + r"""Solves a given inequality or equation with set as output + + Parameters + ========== + + f : Expr or a relational. + The target equation or inequality + symbol : Symbol + The variable for which the equation is solved + domain : Set + The domain over which the equation is solved + + Returns + ======= + + Set + A set of values for `symbol` for which `f` is True or is equal to + zero. An :class:`~.EmptySet` is returned if `f` is False or nonzero. + A :class:`~.ConditionSet` is returned as unsolved object if algorithms + to evaluate complete solution are not yet implemented. + + ``solveset`` claims to be complete in the solution set that it returns. + + Raises + ====== + + NotImplementedError + The algorithms to solve inequalities in complex domain are + not yet implemented. + ValueError + The input is not valid. + RuntimeError + It is a bug, please report to the github issue tracker. + + + Notes + ===== + + Python interprets 0 and 1 as False and True, respectively, but + in this function they refer to solutions of an expression. So 0 and 1 + return the domain and EmptySet, respectively, while True and False + return the opposite (as they are assumed to be solutions of relational + expressions). + + + See Also + ======== + + solveset_real: solver for real domain + solveset_complex: solver for complex domain + + Examples + ======== + + >>> from sympy import exp, sin, Symbol, pprint, S, Eq + >>> from sympy.solvers.solveset import solveset, solveset_real + + * The default domain is complex. Not specifying a domain will lead + to the solving of the equation in the complex domain (and this + is not affected by the assumptions on the symbol): + + >>> x = Symbol('x') + >>> pprint(solveset(exp(x) - 1, x), use_unicode=False) + {2*n*I*pi | n in Integers} + + >>> x = Symbol('x', real=True) + >>> pprint(solveset(exp(x) - 1, x), use_unicode=False) + {2*n*I*pi | n in Integers} + + * If you want to use ``solveset`` to solve the equation in the + real domain, provide a real domain. (Using ``solveset_real`` + does this automatically.) + + >>> R = S.Reals + >>> x = Symbol('x') + >>> solveset(exp(x) - 1, x, R) + {0} + >>> solveset_real(exp(x) - 1, x) + {0} + + The solution is unaffected by assumptions on the symbol: + + >>> p = Symbol('p', positive=True) + >>> pprint(solveset(p**2 - 4)) + {-2, 2} + + When a :class:`~.ConditionSet` is returned, symbols with assumptions that + would alter the set are replaced with more generic symbols: + + >>> i = Symbol('i', imaginary=True) + >>> solveset(Eq(i**2 + i*sin(i), 1), i, domain=S.Reals) + ConditionSet(_R, Eq(_R**2 + _R*sin(_R) - 1, 0), Reals) + + * Inequalities can be solved over the real domain only. Use of a complex + domain leads to a NotImplementedError. + + >>> solveset(exp(x) > 1, x, R) + Interval.open(0, oo) + + """ + f = sympify(f) + symbol = sympify(symbol) + + if f is S.true: + return domain + + if f is S.false: + return S.EmptySet + + if not isinstance(f, (Expr, Relational, Number)): + raise ValueError("%s is not a valid SymPy expression" % f) + + if not isinstance(symbol, (Expr, Relational)) and symbol is not None: + raise ValueError("%s is not a valid SymPy symbol" % (symbol,)) + + if not isinstance(domain, Set): + raise ValueError("%s is not a valid domain" %(domain)) + + free_symbols = f.free_symbols + + if f.has(Piecewise): + f = piecewise_fold(f) + + if symbol is None and not free_symbols: + b = Eq(f, 0) + if b is S.true: + return domain + elif b is S.false: + return S.EmptySet + else: + raise NotImplementedError(filldedent(''' + relationship between value and 0 is unknown: %s''' % b)) + + if symbol is None: + if len(free_symbols) == 1: + symbol = free_symbols.pop() + elif free_symbols: + raise ValueError(filldedent(''' + The independent variable must be specified for a + multivariate equation.''')) + elif not isinstance(symbol, Symbol): + f, s, swap = recast_to_symbols([f], [symbol]) + # the xreplace will be needed if a ConditionSet is returned + return solveset(f[0], s[0], domain).xreplace(swap) + + # solveset should ignore assumptions on symbols + newsym = None + if domain.is_subset(S.Reals): + if symbol._assumptions_orig != {'real': True}: + newsym = Dummy('R', real=True) + elif domain.is_subset(S.Complexes): + if symbol._assumptions_orig != {'complex': True}: + newsym = Dummy('C', complex=True) + + if newsym is not None: + rv = solveset(f.xreplace({symbol: newsym}), newsym, domain) + # try to use the original symbol if possible + try: + _rv = rv.xreplace({newsym: symbol}) + except TypeError: + _rv = rv + if rv.dummy_eq(_rv): + rv = _rv + return rv + + # Abs has its own handling method which avoids the + # rewriting property that the first piece of abs(x) + # is for x >= 0 and the 2nd piece for x < 0 -- solutions + # can look better if the 2nd condition is x <= 0. Since + # the solution is a set, duplication of results is not + # an issue, e.g. {y, -y} when y is 0 will be {0} + f, mask = _masked(f, Abs) + f = f.rewrite(Piecewise) # everything that's not an Abs + for d, e in mask: + # everything *in* an Abs + e = e.func(e.args[0].rewrite(Piecewise)) + f = f.xreplace({d: e}) + f = piecewise_fold(f) + + return _solveset(f, symbol, domain, _check=True) + + +def solveset_real(f, symbol): + return solveset(f, symbol, S.Reals) + + +def solveset_complex(f, symbol): + return solveset(f, symbol, S.Complexes) + + +def _solveset_multi(eqs, syms, domains): + '''Basic implementation of a multivariate solveset. + + For internal use (not ready for public consumption)''' + + rep = {} + for sym, dom in zip(syms, domains): + if dom is S.Reals: + rep[sym] = Symbol(sym.name, real=True) + eqs = [eq.subs(rep) for eq in eqs] + syms = [sym.subs(rep) for sym in syms] + + syms = tuple(syms) + + if len(eqs) == 0: + return ProductSet(*domains) + + if len(syms) == 1: + sym = syms[0] + domain = domains[0] + solsets = [solveset(eq, sym, domain) for eq in eqs] + solset = Intersection(*solsets) + return ImageSet(Lambda((sym,), (sym,)), solset).doit() + + eqs = sorted(eqs, key=lambda eq: len(eq.free_symbols & set(syms))) + + for n, eq in enumerate(eqs): + sols = [] + all_handled = True + for sym in syms: + if sym not in eq.free_symbols: + continue + sol = solveset(eq, sym, domains[syms.index(sym)]) + + if isinstance(sol, FiniteSet): + i = syms.index(sym) + symsp = syms[:i] + syms[i+1:] + domainsp = domains[:i] + domains[i+1:] + eqsp = eqs[:n] + eqs[n+1:] + for s in sol: + eqsp_sub = [eq.subs(sym, s) for eq in eqsp] + sol_others = _solveset_multi(eqsp_sub, symsp, domainsp) + fun = Lambda((symsp,), symsp[:i] + (s,) + symsp[i:]) + sols.append(ImageSet(fun, sol_others).doit()) + else: + all_handled = False + if all_handled: + return Union(*sols) + + +def solvify(f, symbol, domain): + """Solves an equation using solveset and returns the solution in accordance + with the `solve` output API. + + Returns + ======= + + We classify the output based on the type of solution returned by `solveset`. + + Solution | Output + ---------------------------------------- + FiniteSet | list + + ImageSet, | list (if `f` is periodic) + Union | + + Union | list (with FiniteSet) + + EmptySet | empty list + + Others | None + + + Raises + ====== + + NotImplementedError + A ConditionSet is the input. + + Examples + ======== + + >>> from sympy.solvers.solveset import solvify + >>> from sympy.abc import x + >>> from sympy import S, tan, sin, exp + >>> solvify(x**2 - 9, x, S.Reals) + [-3, 3] + >>> solvify(sin(x) - 1, x, S.Reals) + [pi/2] + >>> solvify(tan(x), x, S.Reals) + [0] + >>> solvify(exp(x) - 1, x, S.Complexes) + + >>> solvify(exp(x) - 1, x, S.Reals) + [0] + + """ + solution_set = solveset(f, symbol, domain) + result = None + if solution_set is S.EmptySet: + result = [] + + elif isinstance(solution_set, ConditionSet): + raise NotImplementedError('solveset is unable to solve this equation.') + + elif isinstance(solution_set, FiniteSet): + result = list(solution_set) + + else: + period = periodicity(f, symbol) + if period is not None: + solutions = S.EmptySet + iter_solutions = () + if isinstance(solution_set, ImageSet): + iter_solutions = (solution_set,) + elif isinstance(solution_set, Union): + if all(isinstance(i, ImageSet) for i in solution_set.args): + iter_solutions = solution_set.args + + for solution in iter_solutions: + solutions += solution.intersect(Interval(0, period, False, True)) + + if isinstance(solutions, FiniteSet): + result = list(solutions) + + else: + solution = solution_set.intersect(domain) + if isinstance(solution, Union): + # concerned about only FiniteSet with Union but not about ImageSet + # if required could be extend + if any(isinstance(i, FiniteSet) for i in solution.args): + result = [sol for soln in solution.args \ + for sol in soln.args if isinstance(soln,FiniteSet)] + else: + return None + + elif isinstance(solution, FiniteSet): + result += solution + + return result + + +############################################################################### +################################ LINSOLVE ##################################### +############################################################################### + + +def linear_coeffs(eq, *syms, dict=False): + """Return a list whose elements are the coefficients of the + corresponding symbols in the sum of terms in ``eq``. + The additive constant is returned as the last element of the + list. + + Raises + ====== + + NonlinearError + The equation contains a nonlinear term + ValueError + duplicate or unordered symbols are passed + + Parameters + ========== + + dict - (default False) when True, return coefficients as a + dictionary with coefficients keyed to syms that were present; + key 1 gives the constant term + + Examples + ======== + + >>> from sympy.solvers.solveset import linear_coeffs + >>> from sympy.abc import x, y, z + >>> linear_coeffs(3*x + 2*y - 1, x, y) + [3, 2, -1] + + It is not necessary to expand the expression: + + >>> linear_coeffs(x + y*(z*(x*3 + 2) + 3), x) + [3*y*z + 1, y*(2*z + 3)] + + When nonlinear is detected, an error will be raised: + + * even if they would cancel after expansion (so the + situation does not pass silently past the caller's + attention) + + >>> eq = 1/x*(x - 1) + 1/x + >>> linear_coeffs(eq.expand(), x) + [0, 1] + >>> linear_coeffs(eq, x) + Traceback (most recent call last): + ... + NonlinearError: + nonlinear in given generators + + * when there are cross terms + + >>> linear_coeffs(x*(y + 1), x, y) + Traceback (most recent call last): + ... + NonlinearError: + symbol-dependent cross-terms encountered + + * when there are terms that contain an expression + dependent on the symbols that is not linear + + >>> linear_coeffs(x**2, x) + Traceback (most recent call last): + ... + NonlinearError: + nonlinear in given generators + """ + eq = _sympify(eq) + if len(syms) == 1 and iterable(syms[0]) and not isinstance(syms[0], Basic): + raise ValueError('expecting unpacked symbols, *syms') + symset = set(syms) + if len(symset) != len(syms): + raise ValueError('duplicate symbols given') + try: + d, c = _linear_eq_to_dict([eq], symset) + d = d[0] + c = c[0] + except PolyNonlinearError as err: + raise NonlinearError(str(err)) + if dict: + if c: + d[S.One] = c + return d + rv = [S.Zero]*(len(syms) + 1) + rv[-1] = c + for i, k in enumerate(syms): + if k not in d: + continue + rv[i] = d[k] + return rv + + +def linear_eq_to_matrix(equations, *symbols): + r""" + Converts a given System of Equations into Matrix form. Here ``equations`` + must be a linear system of equations in ``symbols``. Element ``M[i, j]`` + corresponds to the coefficient of the jth symbol in the ith equation. + + The Matrix form corresponds to the augmented matrix form. For example: + + .. math:: + + 4x + 2y + 3z & = 1 \\ + 3x + y + z & = -6 \\ + 2x + 4y + 9z & = 2 + + This system will return :math:`A` and :math:`b` as: + + .. math:: + + A = \left[\begin{array}{ccc} + 4 & 2 & 3 \\ + 3 & 1 & 1 \\ + 2 & 4 & 9 + \end{array}\right] \\ + + .. math:: + + b = \left[\begin{array}{c} + 1 \\ -6 \\ 2 + \end{array}\right] + + The only simplification performed is to convert + ``Eq(a, b)`` :math:`\Rightarrow a - b`. + + Raises + ====== + + NonlinearError + The equations contain a nonlinear term. + ValueError + The symbols are not given or are not unique. + + Examples + ======== + + >>> from sympy import linear_eq_to_matrix, symbols + >>> c, x, y, z = symbols('c, x, y, z') + + The coefficients (numerical or symbolic) of the symbols will + be returned as matrices: + + >>> eqns = [c*x + z - 1 - c, y + z, x - y] + >>> A, b = linear_eq_to_matrix(eqns, [x, y, z]) + >>> A + Matrix([ + [c, 0, 1], + [0, 1, 1], + [1, -1, 0]]) + >>> b + Matrix([ + [c + 1], + [ 0], + [ 0]]) + + This routine does not simplify expressions and will raise an error + if nonlinearity is encountered: + + >>> eqns = [ + ... (x**2 - 3*x)/(x - 3) - 3, + ... y**2 - 3*y - y*(y - 4) + x - 4] + >>> linear_eq_to_matrix(eqns, [x, y]) + Traceback (most recent call last): + ... + NonlinearError: + symbol-dependent term can be ignored using `strict=False` + + Simplifying these equations will discard the removable singularity in the + first and reveal the linear structure of the second: + + >>> [e.simplify() for e in eqns] + [x - 3, x + y - 4] + + Any such simplification needed to eliminate nonlinear terms must be done + *before* calling this routine. + + """ + if not symbols: + raise ValueError(filldedent(''' + Symbols must be given, for which coefficients + are to be found. + ''')) + + # Check if 'symbols' is a set and raise an error if it is + if isinstance(symbols[0], set): + raise TypeError( + "Unordered 'set' type is not supported as input for symbols.") + + if hasattr(symbols[0], '__iter__'): + symbols = symbols[0] + + if has_dups(symbols): + raise ValueError('Symbols must be unique') + + equations = sympify(equations) + if isinstance(equations, MatrixBase): + equations = list(equations) + elif isinstance(equations, (Expr, Eq)): + equations = [equations] + elif not is_sequence(equations): + raise ValueError(filldedent(''' + Equation(s) must be given as a sequence, Expr, + Eq or Matrix. + ''')) + + # construct the dictionaries + try: + eq, c = _linear_eq_to_dict(equations, symbols) + except PolyNonlinearError as err: + raise NonlinearError(str(err)) + # prepare output matrices + n, m = shape = len(eq), len(symbols) + ix = dict(zip(symbols, range(m))) + A = zeros(*shape) + for row, d in enumerate(eq): + for k in d: + col = ix[k] + A[row, col] = d[k] + b = Matrix(n, 1, [-i for i in c]) + return A, b + + +def linsolve(system, *symbols): + r""" + Solve system of $N$ linear equations with $M$ variables; both + underdetermined and overdetermined systems are supported. + The possible number of solutions is zero, one or infinite. + Zero solutions throws a ValueError, whereas infinite + solutions are represented parametrically in terms of the given + symbols. For unique solution a :class:`~.FiniteSet` of ordered tuples + is returned. + + All standard input formats are supported: + For the given set of equations, the respective input types + are given below: + + .. math:: 3x + 2y - z = 1 + .. math:: 2x - 2y + 4z = -2 + .. math:: 2x - y + 2z = 0 + + * Augmented matrix form, ``system`` given below: + + $$ \text{system} = \left[{array}{cccc} + 3 & 2 & -1 & 1\\ + 2 & -2 & 4 & -2\\ + 2 & -1 & 2 & 0 + \end{array}\right] $$ + + :: + + system = Matrix([[3, 2, -1, 1], [2, -2, 4, -2], [2, -1, 2, 0]]) + + * List of equations form + + :: + + system = [3x + 2y - z - 1, 2x - 2y + 4z + 2, 2x - y + 2z] + + * Input $A$ and $b$ in matrix form (from $Ax = b$) are given as: + + $$ A = \left[\begin{array}{ccc} + 3 & 2 & -1 \\ + 2 & -2 & 4 \\ + 2 & -1 & 2 + \end{array}\right] \ \ b = \left[\begin{array}{c} + 1 \\ -2 \\ 0 + \end{array}\right] $$ + + :: + + A = Matrix([[3, 2, -1], [2, -2, 4], [2, -1, 2]]) + b = Matrix([[1], [-2], [0]]) + system = (A, b) + + Symbols can always be passed but are actually only needed + when 1) a system of equations is being passed and 2) the + system is passed as an underdetermined matrix and one wants + to control the name of the free variables in the result. + An error is raised if no symbols are used for case 1, but if + no symbols are provided for case 2, internally generated symbols + will be provided. When providing symbols for case 2, there should + be at least as many symbols are there are columns in matrix A. + + The algorithm used here is Gauss-Jordan elimination, which + results, after elimination, in a row echelon form matrix. + + Returns + ======= + + A FiniteSet containing an ordered tuple of values for the + unknowns for which the `system` has a solution. (Wrapping + the tuple in FiniteSet is used to maintain a consistent + output format throughout solveset.) + + Returns EmptySet, if the linear system is inconsistent. + + Raises + ====== + + ValueError + The input is not valid. + The symbols are not given. + + Examples + ======== + + >>> from sympy import Matrix, linsolve, symbols + >>> x, y, z = symbols("x, y, z") + >>> A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 10]]) + >>> b = Matrix([3, 6, 9]) + >>> A + Matrix([ + [1, 2, 3], + [4, 5, 6], + [7, 8, 10]]) + >>> b + Matrix([ + [3], + [6], + [9]]) + >>> linsolve((A, b), [x, y, z]) + {(-1, 2, 0)} + + * Parametric Solution: In case the system is underdetermined, the + function will return a parametric solution in terms of the given + symbols. Those that are free will be returned unchanged. e.g. in + the system below, `z` is returned as the solution for variable z; + it can take on any value. + + >>> A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> b = Matrix([3, 6, 9]) + >>> linsolve((A, b), x, y, z) + {(z - 1, 2 - 2*z, z)} + + If no symbols are given, internally generated symbols will be used. + The ``tau0`` in the third position indicates (as before) that the third + variable -- whatever it is named -- can take on any value: + + >>> linsolve((A, b)) + {(tau0 - 1, 2 - 2*tau0, tau0)} + + * List of equations as input + + >>> Eqns = [3*x + 2*y - z - 1, 2*x - 2*y + 4*z + 2, - x + y/2 - z] + >>> linsolve(Eqns, x, y, z) + {(1, -2, -2)} + + * Augmented matrix as input + + >>> aug = Matrix([[2, 1, 3, 1], [2, 6, 8, 3], [6, 8, 18, 5]]) + >>> aug + Matrix([ + [2, 1, 3, 1], + [2, 6, 8, 3], + [6, 8, 18, 5]]) + >>> linsolve(aug, x, y, z) + {(3/10, 2/5, 0)} + + * Solve for symbolic coefficients + + >>> a, b, c, d, e, f = symbols('a, b, c, d, e, f') + >>> eqns = [a*x + b*y - c, d*x + e*y - f] + >>> linsolve(eqns, x, y) + {((-b*f + c*e)/(a*e - b*d), (a*f - c*d)/(a*e - b*d))} + + * A degenerate system returns solution as set of given + symbols. + + >>> system = Matrix(([0, 0, 0], [0, 0, 0], [0, 0, 0])) + >>> linsolve(system, x, y) + {(x, y)} + + * For an empty system linsolve returns empty set + + >>> linsolve([], x) + EmptySet + + * An error is raised if any nonlinearity is detected, even + if it could be removed with expansion + + >>> linsolve([x*(1/x - 1)], x) + Traceback (most recent call last): + ... + NonlinearError: nonlinear term: 1/x + + >>> linsolve([x*(y + 1)], x, y) + Traceback (most recent call last): + ... + NonlinearError: nonlinear cross-term: x*(y + 1) + + >>> linsolve([x**2 - 1], x) + Traceback (most recent call last): + ... + NonlinearError: nonlinear term: x**2 + """ + if not system: + return S.EmptySet + + # If second argument is an iterable + if symbols and hasattr(symbols[0], '__iter__'): + symbols = symbols[0] + sym_gen = isinstance(symbols, GeneratorType) + dup_msg = 'duplicate symbols given' + + + b = None # if we don't get b the input was bad + # unpack system + + if hasattr(system, '__iter__'): + + # 1). (A, b) + if len(system) == 2 and isinstance(system[0], MatrixBase): + A, b = system + + # 2). (eq1, eq2, ...) + if not isinstance(system[0], MatrixBase): + if sym_gen or not symbols: + raise ValueError(filldedent(''' + When passing a system of equations, the explicit + symbols for which a solution is being sought must + be given as a sequence, too. + ''')) + if len(set(symbols)) != len(symbols): + raise ValueError(dup_msg) + + # + # Pass to the sparse solver implemented in polys. It is important + # that we do not attempt to convert the equations to a matrix + # because that would be very inefficient for large sparse systems + # of equations. + # + eqs = system + eqs = [sympify(eq) for eq in eqs] + try: + sol = _linsolve(eqs, symbols) + except PolyNonlinearError as exc: + # e.g. cos(x) contains an element of the set of generators + raise NonlinearError(str(exc)) + + if sol is None: + return S.EmptySet + + sol = FiniteSet(Tuple(*(sol.get(sym, sym) for sym in symbols))) + return sol + + elif isinstance(system, MatrixBase) and not ( + symbols and not isinstance(symbols, GeneratorType) and + isinstance(symbols[0], MatrixBase)): + # 3). A augmented with b + A, b = system[:, :-1], system[:, -1:] + + if b is None: + raise ValueError("Invalid arguments") + if sym_gen: + symbols = [next(symbols) for i in range(A.cols)] + symset = set(symbols) + if any(symset & (A.free_symbols | b.free_symbols)): + raise ValueError(filldedent(''' + At least one of the symbols provided + already appears in the system to be solved. + One way to avoid this is to use Dummy symbols in + the generator, e.g. numbered_symbols('%s', cls=Dummy) + ''' % symbols[0].name.rstrip('1234567890'))) + elif len(symset) != len(symbols): + raise ValueError(dup_msg) + + if not symbols: + symbols = [Dummy() for _ in range(A.cols)] + name = _uniquely_named_symbol('tau', (A, b), + compare=lambda i: str(i).rstrip('1234567890')).name + gen = numbered_symbols(name) + else: + gen = None + + # This is just a wrapper for solve_lin_sys + eqs = [] + rows = A.tolist() + for rowi, bi in zip(rows, b): + terms = [elem * sym for elem, sym in zip(rowi, symbols) if elem] + terms.append(-bi) + eqs.append(Add(*terms)) + + eqs, ring = sympy_eqs_to_ring(eqs, symbols) + sol = solve_lin_sys(eqs, ring, _raw=False) + if sol is None: + return S.EmptySet + #sol = {sym:val for sym, val in sol.items() if sym != val} + sol = FiniteSet(Tuple(*(sol.get(sym, sym) for sym in symbols))) + + if gen is not None: + solsym = sol.free_symbols + rep = {sym: next(gen) for sym in symbols if sym in solsym} + sol = sol.subs(rep) + + return sol + + +############################################################################## +# ------------------------------nonlinsolve ---------------------------------# +############################################################################## + + +def _return_conditionset(eqs, symbols): + # return conditionset + eqs = (Eq(lhs, 0) for lhs in eqs) + condition_set = ConditionSet( + Tuple(*symbols), And(*eqs), S.Complexes**len(symbols)) + return condition_set + + +def substitution(system, symbols, result=[{}], known_symbols=[], + exclude=[], all_symbols=None): + r""" + Solves the `system` using substitution method. It is used in + :func:`~.nonlinsolve`. This will be called from :func:`~.nonlinsolve` when any + equation(s) is non polynomial equation. + + Parameters + ========== + + system : list of equations + The target system of equations + symbols : list of symbols to be solved. + The variable(s) for which the system is solved + known_symbols : list of solved symbols + Values are known for these variable(s) + result : An empty list or list of dict + If No symbol values is known then empty list otherwise + symbol as keys and corresponding value in dict. + exclude : Set of expression. + Mostly denominator expression(s) of the equations of the system. + Final solution should not satisfy these expressions. + all_symbols : known_symbols + symbols(unsolved). + + Returns + ======= + + A FiniteSet of ordered tuple of values of `all_symbols` for which the + `system` has solution. Order of values in the tuple is same as symbols + present in the parameter `all_symbols`. If parameter `all_symbols` is None + then same as symbols present in the parameter `symbols`. + + Please note that general FiniteSet is unordered, the solution returned + here is not simply a FiniteSet of solutions, rather it is a FiniteSet of + ordered tuple, i.e. the first & only argument to FiniteSet is a tuple of + solutions, which is ordered, & hence the returned solution is ordered. + + Also note that solution could also have been returned as an ordered tuple, + FiniteSet is just a wrapper `{}` around the tuple. It has no other + significance except for the fact it is just used to maintain a consistent + output format throughout the solveset. + + Raises + ====== + + ValueError + The input is not valid. + The symbols are not given. + AttributeError + The input symbols are not :class:`~.Symbol` type. + + Examples + ======== + + >>> from sympy import symbols, substitution + >>> x, y = symbols('x, y', real=True) + >>> substitution([x + y], [x], [{y: 1}], [y], set([]), [x, y]) + {(-1, 1)} + + * When you want a soln not satisfying $x + 1 = 0$ + + >>> substitution([x + y], [x], [{y: 1}], [y], set([x + 1]), [y, x]) + EmptySet + >>> substitution([x + y], [x], [{y: 1}], [y], set([x - 1]), [y, x]) + {(1, -1)} + >>> substitution([x + y - 1, y - x**2 + 5], [x, y]) + {(-3, 4), (2, -1)} + + * Returns both real and complex solution + + >>> x, y, z = symbols('x, y, z') + >>> from sympy import exp, sin + >>> substitution([exp(x) - sin(y), y**2 - 4], [x, y]) + {(ImageSet(Lambda(_n, I*(2*_n*pi + pi) + log(sin(2))), Integers), -2), + (ImageSet(Lambda(_n, 2*_n*I*pi + log(sin(2))), Integers), 2)} + + >>> eqs = [z**2 + exp(2*x) - sin(y), -3 + exp(-y)] + >>> substitution(eqs, [y, z]) + {(-log(3), -sqrt(-exp(2*x) - sin(log(3)))), + (-log(3), sqrt(-exp(2*x) - sin(log(3)))), + (ImageSet(Lambda(_n, 2*_n*I*pi - log(3)), Integers), + ImageSet(Lambda(_n, -sqrt(-exp(2*x) + sin(2*_n*I*pi - log(3)))), Integers)), + (ImageSet(Lambda(_n, 2*_n*I*pi - log(3)), Integers), + ImageSet(Lambda(_n, sqrt(-exp(2*x) + sin(2*_n*I*pi - log(3)))), Integers))} + + """ + + if not system: + return S.EmptySet + + for i, e in enumerate(system): + if isinstance(e, Eq): + system[i] = e.lhs - e.rhs + + if not symbols: + msg = ('Symbols must be given, for which solution of the ' + 'system is to be found.') + raise ValueError(filldedent(msg)) + + if not is_sequence(symbols): + msg = ('symbols should be given as a sequence, e.g. a list.' + 'Not type %s: %s') + raise TypeError(filldedent(msg % (type(symbols), symbols))) + + if not getattr(symbols[0], 'is_Symbol', False): + msg = ('Iterable of symbols must be given as ' + 'second argument, not type %s: %s') + raise ValueError(filldedent(msg % (type(symbols[0]), symbols[0]))) + + # By default `all_symbols` will be same as `symbols` + if all_symbols is None: + all_symbols = symbols + + old_result = result + # storing complements and intersection for particular symbol + complements = {} + intersections = {} + + # when total_solveset_call equals total_conditionset + # it means that solveset failed to solve all eqs. + total_conditionset = -1 + total_solveset_call = -1 + + def _unsolved_syms(eq, sort=False): + """Returns the unsolved symbol present + in the equation `eq`. + """ + free = eq.free_symbols + unsolved = (free - set(known_symbols)) & set(all_symbols) + if sort: + unsolved = list(unsolved) + unsolved.sort(key=default_sort_key) + return unsolved + + # sort such that equation with the fewest potential symbols is first. + # means eq with less number of variable first in the list. + eqs_in_better_order = list( + ordered(system, lambda _: len(_unsolved_syms(_)))) + + def add_intersection_complement(result, intersection_dict, complement_dict): + # If solveset has returned some intersection/complement + # for any symbol, it will be added in the final solution. + final_result = [] + for res in result: + res_copy = res + for key_res, value_res in res.items(): + intersect_set, complement_set = None, None + for key_sym, value_sym in intersection_dict.items(): + if key_sym == key_res: + intersect_set = value_sym + for key_sym, value_sym in complement_dict.items(): + if key_sym == key_res: + complement_set = value_sym + if intersect_set or complement_set: + new_value = FiniteSet(value_res) + if intersect_set and intersect_set != S.Complexes: + new_value = Intersection(new_value, intersect_set) + if complement_set: + new_value = Complement(new_value, complement_set) + if new_value is S.EmptySet: + res_copy = None + break + elif new_value.is_FiniteSet and len(new_value) == 1: + res_copy[key_res] = set(new_value).pop() + else: + res_copy[key_res] = new_value + + if res_copy is not None: + final_result.append(res_copy) + return final_result + + def _extract_main_soln(sym, sol, soln_imageset): + """Separate the Complements, Intersections, ImageSet lambda expr and + its base_set. This function returns the unmasked sol from different classes + of sets and also returns the appended ImageSet elements in a + soln_imageset dict: `{unmasked element: ImageSet}`. + """ + # if there is union, then need to check + # Complement, Intersection, Imageset. + # Order should not be changed. + if isinstance(sol, ConditionSet): + # extracts any solution in ConditionSet + sol = sol.base_set + + if isinstance(sol, Complement): + # extract solution and complement + complements[sym] = sol.args[1] + sol = sol.args[0] + # complement will be added at the end + # using `add_intersection_complement` method + + # if there is union of Imageset or other in soln. + # no testcase is written for this if block + if isinstance(sol, Union): + sol_args = sol.args + sol = S.EmptySet + # We need in sequence so append finteset elements + # and then imageset or other. + for sol_arg2 in sol_args: + if isinstance(sol_arg2, FiniteSet): + sol += sol_arg2 + else: + # ImageSet, Intersection, complement then + # append them directly + sol += FiniteSet(sol_arg2) + + if isinstance(sol, Intersection): + # Interval/Set will be at 0th index always + if sol.args[0] not in (S.Reals, S.Complexes): + # Sometimes solveset returns soln with intersection + # S.Reals or S.Complexes. We don't consider that + # intersection. + intersections[sym] = sol.args[0] + sol = sol.args[1] + # after intersection and complement Imageset should + # be checked. + if isinstance(sol, ImageSet): + soln_imagest = sol + expr2 = sol.lamda.expr + sol = FiniteSet(expr2) + soln_imageset[expr2] = soln_imagest + + if not isinstance(sol, FiniteSet): + sol = FiniteSet(sol) + return sol, soln_imageset + + def _check_exclude(rnew, imgset_yes): + rnew_ = rnew + if imgset_yes: + # replace all dummy variables (Imageset lambda variables) + # with zero before `checksol`. Considering fundamental soln + # for `checksol`. + rnew_copy = rnew.copy() + dummy_n = imgset_yes[0] + for key_res, value_res in rnew_copy.items(): + rnew_copy[key_res] = value_res.subs(dummy_n, 0) + rnew_ = rnew_copy + # satisfy_exclude == true if it satisfies the expr of `exclude` list. + try: + # something like : `Mod(-log(3), 2*I*pi)` can't be + # simplified right now, so `checksol` returns `TypeError`. + # when this issue is fixed this try block should be + # removed. Mod(-log(3), 2*I*pi) == -log(3) + satisfy_exclude = any( + checksol(d, rnew_) for d in exclude) + except TypeError: + satisfy_exclude = None + return satisfy_exclude + + def _restore_imgset(rnew, original_imageset, newresult): + restore_sym = set(rnew.keys()) & \ + set(original_imageset.keys()) + for key_sym in restore_sym: + img = original_imageset[key_sym] + rnew[key_sym] = img + if rnew not in newresult: + newresult.append(rnew) + + def _append_eq(eq, result, res, delete_soln, n=None): + u = Dummy('u') + if n: + eq = eq.subs(n, 0) + satisfy = eq if eq in (True, False) else checksol(u, u, eq, minimal=True) + if satisfy is False: + delete_soln = True + res = {} + else: + result.append(res) + return result, res, delete_soln + + def _append_new_soln(rnew, sym, sol, imgset_yes, soln_imageset, + original_imageset, newresult, eq=None): + """If `rnew` (A dict ) contains valid soln + append it to `newresult` list. + `imgset_yes` is (base, dummy_var) if there was imageset in previously + calculated result(otherwise empty tuple). `original_imageset` is dict + of imageset expr and imageset from this result. + `soln_imageset` dict of imageset expr and imageset of new soln. + """ + satisfy_exclude = _check_exclude(rnew, imgset_yes) + delete_soln = False + # soln should not satisfy expr present in `exclude` list. + if not satisfy_exclude: + local_n = None + # if it is imageset + if imgset_yes: + local_n = imgset_yes[0] + base = imgset_yes[1] + if sym and sol: + # when `sym` and `sol` is `None` means no new + # soln. In that case we will append rnew directly after + # substituting original imagesets in rnew values if present + # (second last line of this function using _restore_imgset) + dummy_list = list(sol.atoms(Dummy)) + # use one dummy `n` which is in + # previous imageset + local_n_list = [ + local_n for i in range( + 0, len(dummy_list))] + + dummy_zip = zip(dummy_list, local_n_list) + lam = Lambda(local_n, sol.subs(dummy_zip)) + rnew[sym] = ImageSet(lam, base) + if eq is not None: + newresult, rnew, delete_soln = _append_eq( + eq, newresult, rnew, delete_soln, local_n) + elif eq is not None: + newresult, rnew, delete_soln = _append_eq( + eq, newresult, rnew, delete_soln) + elif sol in soln_imageset.keys(): + rnew[sym] = soln_imageset[sol] + # restore original imageset + _restore_imgset(rnew, original_imageset, newresult) + else: + newresult.append(rnew) + elif satisfy_exclude: + delete_soln = True + rnew = {} + _restore_imgset(rnew, original_imageset, newresult) + return newresult, delete_soln + + def _new_order_result(result, eq): + # separate first, second priority. `res` that makes `eq` value equals + # to zero, should be used first then other result(second priority). + # If it is not done then we may miss some soln. + first_priority = [] + second_priority = [] + for res in result: + if not any(isinstance(val, ImageSet) for val in res.values()): + if eq.subs(res) == 0: + first_priority.append(res) + else: + second_priority.append(res) + if first_priority or second_priority: + return first_priority + second_priority + return result + + def _solve_using_known_values(result, solver): + """Solves the system using already known solution + (result contains the dict ). + solver is :func:`~.solveset_complex` or :func:`~.solveset_real`. + """ + # stores imageset . + soln_imageset = {} + total_solvest_call = 0 + total_conditionst = 0 + + # sort equations so the one with the fewest potential + # symbols appears first + for index, eq in enumerate(eqs_in_better_order): + newresult = [] + # if imageset, expr is used to solve for other symbol + imgset_yes = False + for res in result: + original_imageset = {} + got_symbol = set() # symbols solved in one iteration + # find the imageset and use its expr. + for k, v in res.items(): + if isinstance(v, ImageSet): + res[k] = v.lamda.expr + original_imageset[k] = v + dummy_n = v.lamda.expr.atoms(Dummy).pop() + (base,) = v.base_sets + imgset_yes = (dummy_n, base) + assert not isinstance(v, FiniteSet) # if so, internal error + # update eq with everything that is known so far + eq2 = eq.subs(res).expand() + if imgset_yes and not eq2.has(imgset_yes[0]): + # The substituted equation simplified in such a way that + # it's no longer necessary to encapsulate a potential new + # solution in an ImageSet. (E.g. at the previous step some + # {n*2*pi} was found as partial solution for one of the + # unknowns, but its main solution expression n*2*pi has now + # been substituted in a trigonometric function.) + imgset_yes = False + + unsolved_syms = _unsolved_syms(eq2, sort=True) + if not unsolved_syms: + if res: + newresult, delete_res = _append_new_soln( + res, None, None, imgset_yes, soln_imageset, + original_imageset, newresult, eq2) + if delete_res: + # `delete_res` is true, means substituting `res` in + # eq2 doesn't return `zero` or deleting the `res` + # (a soln) since it satisfies expr of `exclude` + # list. + result.remove(res) + continue # skip as it's independent of desired symbols + depen1, depen2 = eq2.as_independent(*unsolved_syms) + if (depen1.has(Abs) or depen2.has(Abs)) and solver == solveset_complex: + # Absolute values cannot be inverted in the + # complex domain + continue + soln_imageset = {} + for sym in unsolved_syms: + not_solvable = False + try: + soln = solver(eq2, sym) + total_solvest_call += 1 + soln_new = S.EmptySet + if isinstance(soln, Complement): + # separate solution and complement + complements[sym] = soln.args[1] + soln = soln.args[0] + # complement will be added at the end + if isinstance(soln, Intersection): + # Interval will be at 0th index always + if soln.args[0] != Interval(-oo, oo): + # sometimes solveset returns soln + # with intersection S.Reals, to confirm that + # soln is in domain=S.Reals + intersections[sym] = soln.args[0] + soln_new += soln.args[1] + soln = soln_new if soln_new else soln + if index > 0 and solver == solveset_real: + # one symbol's real soln, another symbol may have + # corresponding complex soln. + if not isinstance(soln, (ImageSet, ConditionSet)): + soln += solveset_complex(eq2, sym) # might give ValueError with Abs + except (NotImplementedError, ValueError): + # If solveset is not able to solve equation `eq2`. Next + # time we may get soln using next equation `eq2` + continue + if isinstance(soln, ConditionSet): + if soln.base_set in (S.Reals, S.Complexes): + soln = S.EmptySet + # don't do `continue` we may get soln + # in terms of other symbol(s) + not_solvable = True + total_conditionst += 1 + else: + soln = soln.base_set + + if soln is not S.EmptySet: + soln, soln_imageset = _extract_main_soln( + sym, soln, soln_imageset) + + for sol in soln: + # sol is not a `Union` since we checked it + # before this loop + sol, soln_imageset = _extract_main_soln( + sym, sol, soln_imageset) + sol = set(sol).pop() # XXX what if there are more solutions? + free = sol.free_symbols + if got_symbol and any( + ss in free for ss in got_symbol + ): + # sol depends on previously solved symbols + # then continue + continue + rnew = res.copy() + # put each solution in res and append the new result + # in the new result list (solution for symbol `s`) + # along with old results. + for k, v in res.items(): + if isinstance(v, Expr) and isinstance(sol, Expr): + # if any unsolved symbol is present + # Then subs known value + rnew[k] = v.subs(sym, sol) + # and add this new solution + if sol in soln_imageset.keys(): + # replace all lambda variables with 0. + imgst = soln_imageset[sol] + rnew[sym] = imgst.lamda( + *[0 for i in range(0, len( + imgst.lamda.variables))]) + else: + rnew[sym] = sol + newresult, delete_res = _append_new_soln( + rnew, sym, sol, imgset_yes, soln_imageset, + original_imageset, newresult) + if delete_res: + # deleting the `res` (a soln) since it satisfies + # eq of `exclude` list + result.remove(res) + # solution got for sym + if not not_solvable: + got_symbol.add(sym) + # next time use this new soln + if newresult: + result = newresult + return result, total_solvest_call, total_conditionst + + new_result_real, solve_call1, cnd_call1 = _solve_using_known_values( + old_result, solveset_real) + new_result_complex, solve_call2, cnd_call2 = _solve_using_known_values( + old_result, solveset_complex) + + # If total_solveset_call is equal to total_conditionset + # then solveset failed to solve all of the equations. + # In this case we return a ConditionSet here. + total_conditionset += (cnd_call1 + cnd_call2) + total_solveset_call += (solve_call1 + solve_call2) + + if total_conditionset == total_solveset_call and total_solveset_call != -1: + return _return_conditionset(eqs_in_better_order, all_symbols) + + # don't keep duplicate solutions + filtered_complex = [] + for i in list(new_result_complex): + for j in list(new_result_real): + if i.keys() != j.keys(): + continue + if all(a.dummy_eq(b) for a, b in zip(i.values(), j.values()) \ + if not (isinstance(a, int) and isinstance(b, int))): + break + else: + filtered_complex.append(i) + # overall result + result = new_result_real + filtered_complex + + result_all_variables = [] + result_infinite = [] + for res in result: + if not res: + # means {None : None} + continue + # If length < len(all_symbols) means infinite soln. + # Some or all the soln is dependent on 1 symbol. + # eg. {x: y+2} then final soln {x: y+2, y: y} + if len(res) < len(all_symbols): + solved_symbols = res.keys() + unsolved = list(filter( + lambda x: x not in solved_symbols, all_symbols)) + for unsolved_sym in unsolved: + res[unsolved_sym] = unsolved_sym + result_infinite.append(res) + if res not in result_all_variables: + result_all_variables.append(res) + + if result_infinite: + # we have general soln + # eg : [{x: -1, y : 1}, {x : -y, y: y}] then + # return [{x : -y, y : y}] + result_all_variables = result_infinite + if intersections or complements: + result_all_variables = add_intersection_complement( + result_all_variables, intersections, complements) + + # convert to ordered tuple + result = S.EmptySet + for r in result_all_variables: + temp = [r[symb] for symb in all_symbols] + result += FiniteSet(tuple(temp)) + return result + + +def _solveset_work(system, symbols): + soln = solveset(system[0], symbols[0]) + if isinstance(soln, FiniteSet): + _soln = FiniteSet(*[(s,) for s in soln]) + return _soln + else: + return FiniteSet(tuple(FiniteSet(soln))) + + +def _handle_positive_dimensional(polys, symbols, denominators): + from sympy.polys.polytools import groebner + # substitution method where new system is groebner basis of the system + _symbols = list(symbols) + _symbols.sort(key=default_sort_key) + basis = groebner(polys, _symbols, polys=True) + new_system = [] + for poly_eq in basis: + new_system.append(poly_eq.as_expr()) + result = [{}] + result = substitution( + new_system, symbols, result, [], + denominators) + return result + + +def _handle_zero_dimensional(polys, symbols, system): + # solve 0 dimensional poly system using `solve_poly_system` + result = solve_poly_system(polys, *symbols) + # May be some extra soln is added because + # we used `unrad` in `_separate_poly_nonpoly`, so + # need to check and remove if it is not a soln. + result_update = S.EmptySet + for res in result: + dict_sym_value = dict(list(zip(symbols, res))) + if all(checksol(eq, dict_sym_value) for eq in system): + result_update += FiniteSet(res) + return result_update + + +def _separate_poly_nonpoly(system, symbols): + polys = [] + polys_expr = [] + nonpolys = [] + # unrad_changed stores a list of expressions containing + # radicals that were processed using unrad + # this is useful if solutions need to be checked later. + unrad_changed = [] + denominators = set() + poly = None + for eq in system: + # Store denom expressions that contain symbols + denominators.update(_simple_dens(eq, symbols)) + # Convert equality to expression + if isinstance(eq, Eq): + eq = eq.lhs - eq.rhs + # try to remove sqrt and rational power + without_radicals = unrad(simplify(eq), *symbols) + if without_radicals: + unrad_changed.append(eq) + eq_unrad, cov = without_radicals + if not cov: + eq = eq_unrad + if isinstance(eq, Expr): + eq = eq.as_numer_denom()[0] + poly = eq.as_poly(*symbols, extension=True) + elif simplify(eq).is_number: + continue + if poly is not None: + polys.append(poly) + polys_expr.append(poly.as_expr()) + else: + nonpolys.append(eq) + return polys, polys_expr, nonpolys, denominators, unrad_changed + + +def _handle_poly(polys, symbols): + # _handle_poly(polys, symbols) -> (poly_sol, poly_eqs) + # + # We will return possible solution information to nonlinsolve as well as a + # new system of polynomial equations to be solved if we cannot solve + # everything directly here. The new system of polynomial equations will be + # a lex-order Groebner basis for the original system. The lex basis + # hopefully separate some of the variables and equations and give something + # easier for substitution to work with. + + # The format for representing solution sets in nonlinsolve and substitution + # is a list of dicts. These are the special cases: + no_information = [{}] # No equations solved yet + no_solutions = [] # The system is inconsistent and has no solutions. + + # If there is no need to attempt further solution of these equations then + # we return no equations: + no_equations = [] + + inexact = any(not p.domain.is_Exact for p in polys) + if inexact: + # The use of Groebner over RR is likely to result incorrectly in an + # inconsistent Groebner basis. So, convert any float coefficients to + # Rational before computing the Groebner basis. + polys = [poly(nsimplify(p, rational=True)) for p in polys] + + # Compute a Groebner basis in grevlex order wrt the ordering given. We will + # try to convert this to lex order later. Usually it seems to be more + # efficient to compute a lex order basis by computing a grevlex basis and + # converting to lex with fglm. + basis = groebner(polys, symbols, order='grevlex', polys=False) + + # + # No solutions (inconsistent equations)? + # + if 1 in basis: + + # No solutions: + poly_sol = no_solutions + poly_eqs = no_equations + + # + # Finite number of solutions (zero-dimensional case) + # + elif basis.is_zero_dimensional: + + # Convert Groebner basis to lex ordering + basis = basis.fglm('lex') + + # Convert polynomial coefficients back to float before calling + # solve_poly_system + if inexact: + basis = [nfloat(p) for p in basis] + + # Solve the zero-dimensional case using solve_poly_system if possible. + # If some polynomials have factors that cannot be solved in radicals + # then this will fail. Using solve_poly_system(..., strict=True) + # ensures that we either get a complete solution set in radicals or + # UnsolvableFactorError will be raised. + try: + result = solve_poly_system(basis, *symbols, strict=True) + except UnsolvableFactorError: + # Failure... not fully solvable in radicals. Return the lex-order + # basis for substitution to handle. + poly_sol = no_information + poly_eqs = list(basis) + else: + # Success! We have a finite solution set and solve_poly_system has + # succeeded in finding all solutions. Return the solutions and also + # an empty list of remaining equations to be solved. + poly_sol = [dict(zip(symbols, res)) for res in result] + poly_eqs = no_equations + + # + # Infinite families of solutions (positive-dimensional case) + # + else: + # In this case the grevlex basis cannot be converted to lex using the + # fglm method and also solve_poly_system cannot solve the equations. We + # would like to return a lex basis but since we can't use fglm we + # compute the lex basis directly here. The time required to recompute + # the basis is generally significantly less than the time required by + # substitution to solve the new system. + poly_sol = no_information + poly_eqs = list(groebner(polys, symbols, order='lex', polys=False)) + + if inexact: + poly_eqs = [nfloat(p) for p in poly_eqs] + + return poly_sol, poly_eqs + + +def nonlinsolve(system, *symbols): + r""" + Solve system of $N$ nonlinear equations with $M$ variables, which means both + under and overdetermined systems are supported. Positive dimensional + system is also supported (A system with infinitely many solutions is said + to be positive-dimensional). In a positive dimensional system the solution will + be dependent on at least one symbol. Returns both real solution + and complex solution (if they exist). + + Parameters + ========== + + system : list of equations + The target system of equations + symbols : list of Symbols + symbols should be given as a sequence eg. list + + Returns + ======= + + A :class:`~.FiniteSet` of ordered tuple of values of `symbols` for which the `system` + has solution. Order of values in the tuple is same as symbols present in + the parameter `symbols`. + + Please note that general :class:`~.FiniteSet` is unordered, the solution + returned here is not simply a :class:`~.FiniteSet` of solutions, rather it + is a :class:`~.FiniteSet` of ordered tuple, i.e. the first and only + argument to :class:`~.FiniteSet` is a tuple of solutions, which is + ordered, and, hence ,the returned solution is ordered. + + Also note that solution could also have been returned as an ordered tuple, + FiniteSet is just a wrapper ``{}`` around the tuple. It has no other + significance except for the fact it is just used to maintain a consistent + output format throughout the solveset. + + For the given set of equations, the respective input types + are given below: + + .. math:: xy - 1 = 0 + .. math:: 4x^2 + y^2 - 5 = 0 + + :: + + system = [x*y - 1, 4*x**2 + y**2 - 5] + symbols = [x, y] + + Raises + ====== + + ValueError + The input is not valid. + The symbols are not given. + AttributeError + The input symbols are not `Symbol` type. + + Examples + ======== + + >>> from sympy import symbols, nonlinsolve + >>> x, y, z = symbols('x, y, z', real=True) + >>> nonlinsolve([x*y - 1, 4*x**2 + y**2 - 5], [x, y]) + {(-1, -1), (-1/2, -2), (1/2, 2), (1, 1)} + + 1. Positive dimensional system and complements: + + >>> from sympy import pprint + >>> from sympy.polys.polytools import is_zero_dimensional + >>> a, b, c, d = symbols('a, b, c, d', extended_real=True) + >>> eq1 = a + b + c + d + >>> eq2 = a*b + b*c + c*d + d*a + >>> eq3 = a*b*c + b*c*d + c*d*a + d*a*b + >>> eq4 = a*b*c*d - 1 + >>> system = [eq1, eq2, eq3, eq4] + >>> is_zero_dimensional(system) + False + >>> pprint(nonlinsolve(system, [a, b, c, d]), use_unicode=False) + -1 1 1 -1 + {(---, -d, -, {d} \ {0}), (-, -d, ---, {d} \ {0})} + d d d d + >>> nonlinsolve([(x+y)**2 - 4, x + y - 2], [x, y]) + {(2 - y, y)} + + 2. If some of the equations are non-polynomial then `nonlinsolve` + will call the ``substitution`` function and return real and complex solutions, + if present. + + >>> from sympy import exp, sin + >>> nonlinsolve([exp(x) - sin(y), y**2 - 4], [x, y]) + {(ImageSet(Lambda(_n, I*(2*_n*pi + pi) + log(sin(2))), Integers), -2), + (ImageSet(Lambda(_n, 2*_n*I*pi + log(sin(2))), Integers), 2)} + + 3. If system is non-linear polynomial and zero-dimensional then it + returns both solution (real and complex solutions, if present) using + :func:`~.solve_poly_system`: + + >>> from sympy import sqrt + >>> nonlinsolve([x**2 - 2*y**2 -2, x*y - 2], [x, y]) + {(-2, -1), (2, 1), (-sqrt(2)*I, sqrt(2)*I), (sqrt(2)*I, -sqrt(2)*I)} + + 4. ``nonlinsolve`` can solve some linear (zero or positive dimensional) + system (because it uses the :func:`sympy.polys.polytools.groebner` function to get the + groebner basis and then uses the ``substitution`` function basis as the + new `system`). But it is not recommended to solve linear system using + ``nonlinsolve``, because :func:`~.linsolve` is better for general linear systems. + + >>> nonlinsolve([x + 2*y -z - 3, x - y - 4*z + 9, y + z - 4], [x, y, z]) + {(3*z - 5, 4 - z, z)} + + 5. System having polynomial equations and only real solution is + solved using :func:`~.solve_poly_system`: + + >>> e1 = sqrt(x**2 + y**2) - 10 + >>> e2 = sqrt(y**2 + (-x + 10)**2) - 3 + >>> nonlinsolve((e1, e2), (x, y)) + {(191/20, -3*sqrt(391)/20), (191/20, 3*sqrt(391)/20)} + >>> nonlinsolve([x**2 + 2/y - 2, x + y - 3], [x, y]) + {(1, 2), (1 - sqrt(5), 2 + sqrt(5)), (1 + sqrt(5), 2 - sqrt(5))} + >>> nonlinsolve([x**2 + 2/y - 2, x + y - 3], [y, x]) + {(2, 1), (2 - sqrt(5), 1 + sqrt(5)), (2 + sqrt(5), 1 - sqrt(5))} + + 6. It is better to use symbols instead of trigonometric functions or + :class:`~.Function`. For example, replace $\sin(x)$ with a symbol, replace + $f(x)$ with a symbol and so on. Get a solution from ``nonlinsolve`` and then + use :func:`~.solveset` to get the value of $x$. + + How nonlinsolve is better than old solver ``_solve_system`` : + ============================================================= + + 1. A positive dimensional system solver: nonlinsolve can return + solution for positive dimensional system. It finds the + Groebner Basis of the positive dimensional system(calling it as + basis) then we can start solving equation(having least number of + variable first in the basis) using solveset and substituting that + solved solutions into other equation(of basis) to get solution in + terms of minimum variables. Here the important thing is how we + are substituting the known values and in which equations. + + 2. Real and complex solutions: nonlinsolve returns both real + and complex solution. If all the equations in the system are polynomial + then using :func:`~.solve_poly_system` both real and complex solution is returned. + If all the equations in the system are not polynomial equation then goes to + ``substitution`` method with this polynomial and non polynomial equation(s), + to solve for unsolved variables. Here to solve for particular variable + solveset_real and solveset_complex is used. For both real and complex + solution ``_solve_using_known_values`` is used inside ``substitution`` + (``substitution`` will be called when any non-polynomial equation is present). + If a solution is valid its general solution is added to the final result. + + 3. :class:`~.Complement` and :class:`~.Intersection` will be added: + nonlinsolve maintains dict for complements and intersections. If solveset + find complements or/and intersections with any interval or set during the + execution of ``substitution`` function, then complement or/and + intersection for that variable is added before returning final solution. + + """ + if not system: + return S.EmptySet + + if not symbols: + msg = ('Symbols must be given, for which solution of the ' + 'system is to be found.') + raise ValueError(filldedent(msg)) + + if hasattr(symbols[0], '__iter__'): + symbols = symbols[0] + + if not is_sequence(symbols) or not symbols: + msg = ('Symbols must be given, for which solution of the ' + 'system is to be found.') + raise IndexError(filldedent(msg)) + + symbols = list(map(_sympify, symbols)) + system, symbols, swap = recast_to_symbols(system, symbols) + if swap: + soln = nonlinsolve(system, symbols) + return FiniteSet(*[tuple(i.xreplace(swap) for i in s) for s in soln]) + + if len(system) == 1 and len(symbols) == 1: + return _solveset_work(system, symbols) + + # main code of def nonlinsolve() starts from here + + polys, polys_expr, nonpolys, denominators, unrad_changed = \ + _separate_poly_nonpoly(system, symbols) + + poly_eqs = [] + poly_sol = [{}] + + if polys: + poly_sol, poly_eqs = _handle_poly(polys, symbols) + if poly_sol and poly_sol[0]: + poly_syms = set().union(*(eq.free_symbols for eq in polys)) + unrad_syms = set().union(*(eq.free_symbols for eq in unrad_changed)) + if unrad_syms == poly_syms and unrad_changed: + # if all the symbols have been solved by _handle_poly + # and unrad has been used then check solutions + poly_sol = [sol for sol in poly_sol if checksol(unrad_changed, sol)] + + # Collect together the unsolved polynomials with the non-polynomial + # equations. + remaining = poly_eqs + nonpolys + + # to_tuple converts a solution dictionary to a tuple containing the + # value for each symbol + to_tuple = lambda sol: tuple(sol[s] for s in symbols) + + if not remaining: + # If there is nothing left to solve then return the solution from + # solve_poly_system directly. + return FiniteSet(*map(to_tuple, poly_sol)) + else: + # Here we handle: + # + # 1. The Groebner basis if solve_poly_system failed. + # 2. The Groebner basis in the positive-dimensional case. + # 3. Any non-polynomial equations + # + # If solve_poly_system did succeed then we pass those solutions in as + # preliminary results. + subs_res = substitution(remaining, symbols, result=poly_sol, exclude=denominators) + + if not isinstance(subs_res, FiniteSet): + return subs_res + + # check solutions produced by substitution. Currently, checking is done for + # only those solutions which have non-Set variable values. + if unrad_changed: + result = [dict(zip(symbols, sol)) for sol in subs_res.args] + correct_sols = [sol for sol in result if any(isinstance(v, Set) for v in sol) + or checksol(unrad_changed, sol) != False] + return FiniteSet(*map(to_tuple, correct_sols)) + else: + return subs_res diff --git a/.venv/lib/python3.13/site-packages/sympy/testing/__init__.py b/.venv/lib/python3.13/site-packages/sympy/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa66402955e7d5d2b27cccc6627b42d33f4cd855 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/testing/__init__.py @@ -0,0 +1,10 @@ +"""This module contains code for running the tests in SymPy.""" + + +from .runtests import doctest +from .runtests_pytest import test + + +__all__ = [ + 'test', 'doctest', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/testing/matrices.py b/.venv/lib/python3.13/site-packages/sympy/testing/matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..236a384366df7f69d0d92f43f7e007e95c12388c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/testing/matrices.py @@ -0,0 +1,8 @@ +def allclose(A, B, rtol=1e-05, atol=1e-08): + if len(A) != len(B): + return False + + for x, y in zip(A, B): + if abs(x-y) > atol + rtol * max(abs(x), abs(y)): + return False + return True diff --git a/.venv/lib/python3.13/site-packages/sympy/testing/pytest.py b/.venv/lib/python3.13/site-packages/sympy/testing/pytest.py new file mode 100644 index 0000000000000000000000000000000000000000..498515a2d3c0a167a5f7067753736898cfa64799 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/testing/pytest.py @@ -0,0 +1,392 @@ +"""py.test hacks to support XFAIL/XPASS""" + +import platform +import sys +import re +import functools +import os +import contextlib +import warnings +import inspect +import pathlib +from typing import Any, Callable + +from sympy.utilities.exceptions import SymPyDeprecationWarning +# Imported here for backwards compatibility. Note: do not import this from +# here in library code (importing sympy.pytest in library code will break the +# pytest integration). +from sympy.utilities.exceptions import ignore_warnings # noqa:F401 + +ON_CI = os.getenv('CI', None) == "true" + +try: + import pytest + USE_PYTEST = getattr(sys, '_running_pytest', False) +except ImportError: + USE_PYTEST = False + +IS_WASM: bool = sys.platform == 'emscripten' or platform.machine() in ["wasm32", "wasm64"] + +raises: Callable[[Any, Any], Any] +XFAIL: Callable[[Any], Any] +skip: Callable[[Any], Any] +SKIP: Callable[[Any], Any] +slow: Callable[[Any], Any] +tooslow: Callable[[Any], Any] +nocache_fail: Callable[[Any], Any] + + +if USE_PYTEST: + raises = pytest.raises + skip = pytest.skip + XFAIL = pytest.mark.xfail + SKIP = pytest.mark.skip + slow = pytest.mark.slow + tooslow = pytest.mark.tooslow + nocache_fail = pytest.mark.nocache_fail + from _pytest.outcomes import Failed + +else: + # Not using pytest so define the things that would have been imported from + # there. + + # _pytest._code.code.ExceptionInfo + class ExceptionInfo: + def __init__(self, value): + self.value = value + + def __repr__(self): + return "".format(self.value) + + + def raises(expectedException, code=None): + """ + Tests that ``code`` raises the exception ``expectedException``. + + ``code`` may be a callable, such as a lambda expression or function + name. + + If ``code`` is not given or None, ``raises`` will return a context + manager for use in ``with`` statements; the code to execute then + comes from the scope of the ``with``. + + ``raises()`` does nothing if the callable raises the expected exception, + otherwise it raises an AssertionError. + + Examples + ======== + + >>> from sympy.testing.pytest import raises + + >>> raises(ZeroDivisionError, lambda: 1/0) + + >>> raises(ZeroDivisionError, lambda: 1/2) + Traceback (most recent call last): + ... + Failed: DID NOT RAISE + + >>> with raises(ZeroDivisionError): + ... n = 1/0 + >>> with raises(ZeroDivisionError): + ... n = 1/2 + Traceback (most recent call last): + ... + Failed: DID NOT RAISE + + Note that you cannot test multiple statements via + ``with raises``: + + >>> with raises(ZeroDivisionError): + ... n = 1/0 # will execute and raise, aborting the ``with`` + ... n = 9999/0 # never executed + + This is just what ``with`` is supposed to do: abort the + contained statement sequence at the first exception and let + the context manager deal with the exception. + + To test multiple statements, you'll need a separate ``with`` + for each: + + >>> with raises(ZeroDivisionError): + ... n = 1/0 # will execute and raise + >>> with raises(ZeroDivisionError): + ... n = 9999/0 # will also execute and raise + + """ + if code is None: + return RaisesContext(expectedException) + elif callable(code): + try: + code() + except expectedException as e: + return ExceptionInfo(e) + raise Failed("DID NOT RAISE") + elif isinstance(code, str): + raise TypeError( + '\'raises(xxx, "code")\' has been phased out; ' + 'change \'raises(xxx, "expression")\' ' + 'to \'raises(xxx, lambda: expression)\', ' + '\'raises(xxx, "statement")\' ' + 'to \'with raises(xxx): statement\'') + else: + raise TypeError( + 'raises() expects a callable for the 2nd argument.') + + class RaisesContext: + def __init__(self, expectedException): + self.expectedException = expectedException + + def __enter__(self): + return None + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is None: + raise Failed("DID NOT RAISE") + return issubclass(exc_type, self.expectedException) + + class XFail(Exception): + pass + + class XPass(Exception): + pass + + class Skipped(Exception): + pass + + class Failed(Exception): # type: ignore + pass + + def XFAIL(func): + def wrapper(): + try: + func() + except Exception as e: + message = str(e) + if message != "Timeout": + raise XFail(func.__name__) + else: + raise Skipped("Timeout") + raise XPass(func.__name__) + + wrapper = functools.update_wrapper(wrapper, func) + return wrapper + + def skip(str): + raise Skipped(str) + + def SKIP(reason): + """Similar to ``skip()``, but this is a decorator. """ + def wrapper(func): + def func_wrapper(): + raise Skipped(reason) + + func_wrapper = functools.update_wrapper(func_wrapper, func) + return func_wrapper + + return wrapper + + def slow(func): + func._slow = True + + def func_wrapper(): + func() + + func_wrapper = functools.update_wrapper(func_wrapper, func) + func_wrapper.__wrapped__ = func + return func_wrapper + + def tooslow(func): + func._slow = True + func._tooslow = True + + def func_wrapper(): + skip("Too slow") + + func_wrapper = functools.update_wrapper(func_wrapper, func) + func_wrapper.__wrapped__ = func + return func_wrapper + + def nocache_fail(func): + "Dummy decorator for marking tests that fail when cache is disabled" + return func + +@contextlib.contextmanager +def warns(warningcls, *, match='', test_stacklevel=True): + ''' + Like raises but tests that warnings are emitted. + + >>> from sympy.testing.pytest import warns + >>> import warnings + + >>> with warns(UserWarning): + ... warnings.warn('deprecated', UserWarning, stacklevel=2) + + >>> with warns(UserWarning): + ... pass + Traceback (most recent call last): + ... + Failed: DID NOT WARN. No warnings of type UserWarning\ + was emitted. The list of emitted warnings is: []. + + ``test_stacklevel`` makes it check that the ``stacklevel`` parameter to + ``warn()`` is set so that the warning shows the user line of code (the + code under the warns() context manager). Set this to False if this is + ambiguous or if the context manager does not test the direct user code + that emits the warning. + + If the warning is a ``SymPyDeprecationWarning``, this additionally tests + that the ``active_deprecations_target`` is a real target in the + ``active-deprecations.md`` file. + + ''' + # Absorbs all warnings in warnrec + with warnings.catch_warnings(record=True) as warnrec: + # Any warning other than the one we are looking for is an error + warnings.simplefilter("error") + warnings.filterwarnings("always", category=warningcls) + # Now run the test + yield warnrec + + # Raise if expected warning not found + if not any(issubclass(w.category, warningcls) for w in warnrec): + msg = ('Failed: DID NOT WARN.' + ' No warnings of type %s was emitted.' + ' The list of emitted warnings is: %s.' + ) % (warningcls, [w.message for w in warnrec]) + raise Failed(msg) + + # We don't include the match in the filter above because it would then + # fall to the error filter, so we instead manually check that it matches + # here + for w in warnrec: + # Should always be true due to the filters above + assert issubclass(w.category, warningcls) + if not re.compile(match, re.IGNORECASE).match(str(w.message)): + raise Failed(f"Failed: WRONG MESSAGE. A warning with of the correct category ({warningcls.__name__}) was issued, but it did not match the given match regex ({match!r})") + + if test_stacklevel: + for f in inspect.stack(): + thisfile = f.filename + file = os.path.split(thisfile)[1] + if file.startswith('test_'): + break + elif file == 'doctest.py': + # skip the stacklevel testing in the doctests of this + # function + return + else: + raise RuntimeError("Could not find the file for the given warning to test the stacklevel") + for w in warnrec: + if w.filename != thisfile: + msg = f'''\ +Failed: Warning has the wrong stacklevel. The warning stacklevel needs to be +set so that the line of code shown in the warning message is user code that +calls the deprecated code (the current stacklevel is showing code from +{w.filename} (line {w.lineno}), expected {thisfile})'''.replace('\n', ' ') + raise Failed(msg) + + if warningcls == SymPyDeprecationWarning: + this_file = pathlib.Path(__file__) + active_deprecations_file = (this_file.parent.parent.parent / 'doc' / + 'src' / 'explanation' / + 'active-deprecations.md') + if not active_deprecations_file.exists(): + # We can only test that the active_deprecations_target works if we are + # in the git repo. + return + targets = [] + for w in warnrec: + targets.append(w.message.active_deprecations_target) + text = pathlib.Path(active_deprecations_file).read_text(encoding="utf-8") + for target in targets: + if f'({target})=' not in text: + raise Failed(f"The active deprecations target {target!r} does not appear to be a valid target in the active-deprecations.md file ({active_deprecations_file}).") + +def _both_exp_pow(func): + """ + Decorator used to run the test twice: the first time `e^x` is represented + as ``Pow(E, x)``, the second time as ``exp(x)`` (exponential object is not + a power). + + This is a temporary trick helping to manage the elimination of the class + ``exp`` in favor of a replacement by ``Pow(E, ...)``. + """ + from sympy.core.parameters import _exp_is_pow + + def func_wrap(): + with _exp_is_pow(True): + func() + with _exp_is_pow(False): + func() + + wrapper = functools.update_wrapper(func_wrap, func) + return wrapper + + +@contextlib.contextmanager +def warns_deprecated_sympy(): + ''' + Shorthand for ``warns(SymPyDeprecationWarning)`` + + This is the recommended way to test that ``SymPyDeprecationWarning`` is + emitted for deprecated features in SymPy. To test for other warnings use + ``warns``. To suppress warnings without asserting that they are emitted + use ``ignore_warnings``. + + .. note:: + + ``warns_deprecated_sympy()`` is only intended for internal use in the + SymPy test suite to test that a deprecation warning triggers properly. + All other code in the SymPy codebase, including documentation examples, + should not use deprecated behavior. + + If you are a user of SymPy and you want to disable + SymPyDeprecationWarnings, use ``warnings`` filters (see + :ref:`silencing-sympy-deprecation-warnings`). + + >>> from sympy.testing.pytest import warns_deprecated_sympy + >>> from sympy.utilities.exceptions import sympy_deprecation_warning + >>> with warns_deprecated_sympy(): + ... sympy_deprecation_warning("Don't use", + ... deprecated_since_version="1.0", + ... active_deprecations_target="active-deprecations") + + >>> with warns_deprecated_sympy(): + ... pass + Traceback (most recent call last): + ... + Failed: DID NOT WARN. No warnings of type \ + SymPyDeprecationWarning was emitted. The list of emitted warnings is: []. + + .. note:: + + Sometimes the stacklevel test will fail because the same warning is + emitted multiple times. In this case, you can use + :func:`sympy.utilities.exceptions.ignore_warnings` in the code to + prevent the ``SymPyDeprecationWarning`` from being emitted again + recursively. In rare cases it is impossible to have a consistent + ``stacklevel`` for deprecation warnings because different ways of + calling a function will produce different call stacks.. In those cases, + use ``warns(SymPyDeprecationWarning)`` instead. + + See Also + ======== + sympy.utilities.exceptions.SymPyDeprecationWarning + sympy.utilities.exceptions.sympy_deprecation_warning + sympy.utilities.decorator.deprecated + + ''' + with warns(SymPyDeprecationWarning): + yield + + +def skip_under_pyodide(message): + """Decorator to skip a test if running under Pyodide/WASM.""" + def decorator(test_func): + @functools.wraps(test_func) + def test_wrapper(): + if IS_WASM: + skip(message) + return test_func() + return test_wrapper + return decorator diff --git a/.venv/lib/python3.13/site-packages/sympy/testing/quality_unicode.py b/.venv/lib/python3.13/site-packages/sympy/testing/quality_unicode.py new file mode 100644 index 0000000000000000000000000000000000000000..d43623ff5112610e377347f50c6a40a15810644b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/testing/quality_unicode.py @@ -0,0 +1,102 @@ +import re +import fnmatch + + +message_unicode_B = \ + "File contains a unicode character : %s, line %s. " \ + "But not in the whitelist. " \ + "Add the file to the whitelist in " + __file__ +message_unicode_D = \ + "File does not contain a unicode character : %s." \ + "but is in the whitelist. " \ + "Remove the file from the whitelist in " + __file__ + + +encoding_header_re = re.compile( + r'^[ \t\f]*#.*?coding[:=][ \t]*([-_.a-zA-Z0-9]+)') + +# Whitelist pattern for files which can have unicode. +unicode_whitelist = [ + # Author names can include non-ASCII characters + r'*/bin/authors_update.py', + r'*/bin/mailmap_check.py', + + # These files have functions and test functions for unicode input and + # output. + r'*/sympy/testing/tests/test_code_quality.py', + r'*/sympy/physics/vector/tests/test_printing.py', + r'*/physics/quantum/tests/test_printing.py', + r'*/sympy/vector/tests/test_printing.py', + r'*/sympy/parsing/tests/test_sympy_parser.py', + r'*/sympy/printing/pretty/stringpict.py', + r'*/sympy/printing/pretty/tests/test_pretty.py', + r'*/sympy/printing/tests/test_conventions.py', + r'*/sympy/printing/tests/test_preview.py', + r'*/liealgebras/type_g.py', + r'*/liealgebras/weyl_group.py', + r'*/liealgebras/tests/test_type_G.py', + + # wigner.py and polarization.py have unicode doctests. These probably + # don't need to be there but some of the examples that are there are + # pretty ugly without use_unicode (matrices need to be wrapped across + # multiple lines etc) + r'*/sympy/physics/wigner.py', + r'*/sympy/physics/optics/polarization.py', + + # joint.py uses some unicode for variable names in the docstrings + r'*/sympy/physics/mechanics/joint.py', + + # lll method has unicode in docstring references and author name + r'*/sympy/polys/matrices/domainmatrix.py', + r'*/sympy/matrices/repmatrix.py', + + # Explanation of symbols uses greek letters + r'*/sympy/core/symbol.py', +] + +unicode_strict_whitelist = [ + r'*/sympy/parsing/latex/_antlr/__init__.py', + # test_mathematica.py uses some unicode for testing Greek characters are working #24055 + r'*/sympy/parsing/tests/test_mathematica.py', +] + + +def _test_this_file_encoding( + fname, test_file, + unicode_whitelist=unicode_whitelist, + unicode_strict_whitelist=unicode_strict_whitelist): + """Test helper function for unicode test + + The test may have to operate on filewise manner, so it had moved + to a separate process. + """ + has_unicode = False + + is_in_whitelist = False + is_in_strict_whitelist = False + for patt in unicode_whitelist: + if fnmatch.fnmatch(fname, patt): + is_in_whitelist = True + break + for patt in unicode_strict_whitelist: + if fnmatch.fnmatch(fname, patt): + is_in_strict_whitelist = True + is_in_whitelist = True + break + + if is_in_whitelist: + for idx, line in enumerate(test_file): + try: + line.encode(encoding='ascii') + except (UnicodeEncodeError, UnicodeDecodeError): + has_unicode = True + + if not has_unicode and not is_in_strict_whitelist: + assert False, message_unicode_D % fname + + else: + for idx, line in enumerate(test_file): + try: + line.encode(encoding='ascii') + except (UnicodeEncodeError, UnicodeDecodeError): + assert False, message_unicode_B % (fname, idx + 1) diff --git a/.venv/lib/python3.13/site-packages/sympy/testing/randtest.py b/.venv/lib/python3.13/site-packages/sympy/testing/randtest.py new file mode 100644 index 0000000000000000000000000000000000000000..3ce2c8c031eec1c886532daba32c96d83e9cf85c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/testing/randtest.py @@ -0,0 +1,19 @@ +""" +.. deprecated:: 1.10 + + ``sympy.testing.randtest`` functions have been moved to + :mod:`sympy.core.random`. + +""" +from sympy.utilities.exceptions import sympy_deprecation_warning + +sympy_deprecation_warning("The sympy.testing.randtest submodule is deprecated. Use sympy.core.random instead.", + deprecated_since_version="1.10", + active_deprecations_target="deprecated-sympy-testing-randtest") + +from sympy.core.random import ( # noqa:F401 + random_complex_number, + verify_numerically, + test_derivative_numerically, + _randrange, + _randint) diff --git a/.venv/lib/python3.13/site-packages/sympy/testing/runtests.py b/.venv/lib/python3.13/site-packages/sympy/testing/runtests.py new file mode 100644 index 0000000000000000000000000000000000000000..e2650e4e6dabaaa07fc25c76cce3d9d28723b0d5 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/testing/runtests.py @@ -0,0 +1,2409 @@ +""" +This is our testing framework. + +Goals: + +* it should be compatible with py.test and operate very similarly + (or identically) +* does not require any external dependencies +* preferably all the functionality should be in this file only +* no magic, just import the test file and execute the test functions, that's it +* portable + +""" + +import os +import sys +import platform +import inspect +import traceback +import pdb +import re +import linecache +import time +from fnmatch import fnmatch +from timeit import default_timer as clock +import doctest as pdoctest # avoid clashing with our doctest() function +from doctest import DocTestFinder, DocTestRunner +import random +import subprocess +import shutil +import signal +import stat +import tempfile +import warnings +from contextlib import contextmanager +from inspect import unwrap +from pathlib import Path + +from sympy.core.cache import clear_cache +from sympy.external import import_module +from sympy.external.gmpy import GROUND_TYPES + +IS_WINDOWS = (os.name == 'nt') +ON_CI = os.getenv('CI', None) + +# empirically generated list of the proportion of time spent running +# an even split of tests. This should periodically be regenerated. +# A list of [.6, .1, .3] would mean that if the tests are evenly split +# into '1/3', '2/3', '3/3', the first split would take 60% of the time, +# the second 10% and the third 30%. These lists are normalized to sum +# to 1, so [60, 10, 30] has the same behavior as [6, 1, 3] or [.6, .1, .3]. +# +# This list can be generated with the code: +# from time import time +# import sympy +# import os +# os.environ["CI"] = 'true' # Mock CI to get more correct densities +# delays, num_splits = [], 30 +# for i in range(1, num_splits + 1): +# tic = time() +# sympy.test(split='{}/{}'.format(i, num_splits), time_balance=False) # Add slow=True for slow tests +# delays.append(time() - tic) +# tot = sum(delays) +# print([round(x / tot, 4) for x in delays]) +SPLIT_DENSITY = [ + 0.0059, 0.0027, 0.0068, 0.0011, 0.0006, + 0.0058, 0.0047, 0.0046, 0.004, 0.0257, + 0.0017, 0.0026, 0.004, 0.0032, 0.0016, + 0.0015, 0.0004, 0.0011, 0.0016, 0.0014, + 0.0077, 0.0137, 0.0217, 0.0074, 0.0043, + 0.0067, 0.0236, 0.0004, 0.1189, 0.0142, + 0.0234, 0.0003, 0.0003, 0.0047, 0.0006, + 0.0013, 0.0004, 0.0008, 0.0007, 0.0006, + 0.0139, 0.0013, 0.0007, 0.0051, 0.002, + 0.0004, 0.0005, 0.0213, 0.0048, 0.0016, + 0.0012, 0.0014, 0.0024, 0.0015, 0.0004, + 0.0005, 0.0007, 0.011, 0.0062, 0.0015, + 0.0021, 0.0049, 0.0006, 0.0006, 0.0011, + 0.0006, 0.0019, 0.003, 0.0044, 0.0054, + 0.0057, 0.0049, 0.0016, 0.0006, 0.0009, + 0.0006, 0.0012, 0.0006, 0.0149, 0.0532, + 0.0076, 0.0041, 0.0024, 0.0135, 0.0081, + 0.2209, 0.0459, 0.0438, 0.0488, 0.0137, + 0.002, 0.0003, 0.0008, 0.0039, 0.0024, + 0.0005, 0.0004, 0.003, 0.056, 0.0026] +SPLIT_DENSITY_SLOW = [0.0086, 0.0004, 0.0568, 0.0003, 0.0032, 0.0005, 0.0004, 0.0013, 0.0016, 0.0648, 0.0198, 0.1285, 0.098, 0.0005, 0.0064, 0.0003, 0.0004, 0.0026, 0.0007, 0.0051, 0.0089, 0.0024, 0.0033, 0.0057, 0.0005, 0.0003, 0.001, 0.0045, 0.0091, 0.0006, 0.0005, 0.0321, 0.0059, 0.1105, 0.216, 0.1489, 0.0004, 0.0003, 0.0006, 0.0483] + +class Skipped(Exception): + pass + +class TimeOutError(Exception): + pass + +class DependencyError(Exception): + pass + + +def _indent(s, indent=4): + """ + Add the given number of space characters to the beginning of + every non-blank line in ``s``, and return the result. + If the string ``s`` is Unicode, it is encoded using the stdout + encoding and the ``backslashreplace`` error handler. + """ + # This regexp matches the start of non-blank lines: + return re.sub('(?m)^(?!$)', indent*' ', s) + + +pdoctest._indent = _indent # type: ignore + +# override reporter to maintain windows and python3 + + +def _report_failure(self, out, test, example, got): + """ + Report that the given example failed. + """ + s = self._checker.output_difference(example, got, self.optionflags) + s = s.encode('raw_unicode_escape').decode('utf8', 'ignore') + out(self._failure_header(test, example) + s) + + +if IS_WINDOWS: + DocTestRunner.report_failure = _report_failure # type: ignore + + +def convert_to_native_paths(lst): + """ + Converts a list of '/' separated paths into a list of + native (os.sep separated) paths and converts to lowercase + if the system is case insensitive. + """ + newlst = [] + for rv in lst: + rv = os.path.join(*rv.split("/")) + # on windows the slash after the colon is dropped + if sys.platform == "win32": + pos = rv.find(':') + if pos != -1: + if rv[pos + 1] != '\\': + rv = rv[:pos + 1] + '\\' + rv[pos + 1:] + newlst.append(os.path.normcase(rv)) + return newlst + + +def get_sympy_dir(): + """ + Returns the root SymPy directory and set the global value + indicating whether the system is case sensitive or not. + """ + this_file = os.path.abspath(__file__) + sympy_dir = os.path.join(os.path.dirname(this_file), "..", "..") + sympy_dir = os.path.normpath(sympy_dir) + return os.path.normcase(sympy_dir) + + +def setup_pprint(disable_line_wrap=True): + from sympy.interactive.printing import init_printing + from sympy.printing.pretty.pretty import pprint_use_unicode + import sympy.interactive.printing as interactive_printing + from sympy.printing.pretty import stringpict + + # Prevent init_printing() in doctests from affecting other doctests + interactive_printing.NO_GLOBAL = True + + # force pprint to be in ascii mode in doctests + use_unicode_prev = pprint_use_unicode(False) + + # disable line wrapping for pprint() outputs + wrap_line_prev = stringpict._GLOBAL_WRAP_LINE + if disable_line_wrap: + stringpict._GLOBAL_WRAP_LINE = False + + # hook our nice, hash-stable strprinter + init_printing(pretty_print=False) + + return use_unicode_prev, wrap_line_prev + + +@contextmanager +def raise_on_deprecated(): + """Context manager to make DeprecationWarning raise an error + + This is to catch SymPyDeprecationWarning from library code while running + tests and doctests. It is important to use this context manager around + each individual test/doctest in case some tests modify the warning + filters. + """ + with warnings.catch_warnings(): + warnings.filterwarnings('error', '.*', DeprecationWarning, module='sympy.*') + yield + + +def run_in_subprocess_with_hash_randomization( + function, function_args=(), + function_kwargs=None, command=sys.executable, + module='sympy.testing.runtests', force=False): + """ + Run a function in a Python subprocess with hash randomization enabled. + + If hash randomization is not supported by the version of Python given, it + returns False. Otherwise, it returns the exit value of the command. The + function is passed to sys.exit(), so the return value of the function will + be the return value. + + The environment variable PYTHONHASHSEED is used to seed Python's hash + randomization. If it is set, this function will return False, because + starting a new subprocess is unnecessary in that case. If it is not set, + one is set at random, and the tests are run. Note that if this + environment variable is set when Python starts, hash randomization is + automatically enabled. To force a subprocess to be created even if + PYTHONHASHSEED is set, pass ``force=True``. This flag will not force a + subprocess in Python versions that do not support hash randomization (see + below), because those versions of Python do not support the ``-R`` flag. + + ``function`` should be a string name of a function that is importable from + the module ``module``, like "_test". The default for ``module`` is + "sympy.testing.runtests". ``function_args`` and ``function_kwargs`` + should be a repr-able tuple and dict, respectively. The default Python + command is sys.executable, which is the currently running Python command. + + This function is necessary because the seed for hash randomization must be + set by the environment variable before Python starts. Hence, in order to + use a predetermined seed for tests, we must start Python in a separate + subprocess. + + Hash randomization was added in the minor Python versions 2.6.8, 2.7.3, + 3.1.5, and 3.2.3, and is enabled by default in all Python versions after + and including 3.3.0. + + Examples + ======== + + >>> from sympy.testing.runtests import ( + ... run_in_subprocess_with_hash_randomization) + >>> # run the core tests in verbose mode + >>> run_in_subprocess_with_hash_randomization("_test", + ... function_args=("core",), + ... function_kwargs={'verbose': True}) # doctest: +SKIP + # Will return 0 if sys.executable supports hash randomization and tests + # pass, 1 if they fail, and False if it does not support hash + # randomization. + + """ + cwd = get_sympy_dir() + # Note, we must return False everywhere, not None, as subprocess.call will + # sometimes return None. + + # First check if the Python version supports hash randomization + # If it does not have this support, it won't recognize the -R flag + p = subprocess.Popen([command, "-RV"], stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, cwd=cwd) + p.communicate() + if p.returncode != 0: + return False + + hash_seed = os.getenv("PYTHONHASHSEED") + if not hash_seed: + os.environ["PYTHONHASHSEED"] = str(random.randrange(2**32)) + else: + if not force: + return False + + function_kwargs = function_kwargs or {} + + # Now run the command + commandstring = ("import sys; from %s import %s;sys.exit(%s(*%s, **%s))" % + (module, function, function, repr(function_args), + repr(function_kwargs))) + + try: + p = subprocess.Popen([command, "-R", "-c", commandstring], cwd=cwd) + p.communicate() + except KeyboardInterrupt: + p.wait() + finally: + # Put the environment variable back, so that it reads correctly for + # the current Python process. + if hash_seed is None: + del os.environ["PYTHONHASHSEED"] + else: + os.environ["PYTHONHASHSEED"] = hash_seed + return p.returncode + + +def run_all_tests(test_args=(), test_kwargs=None, + doctest_args=(), doctest_kwargs=None, + examples_args=(), examples_kwargs=None): + """ + Run all tests. + + Right now, this runs the regular tests (bin/test), the doctests + (bin/doctest), and the examples (examples/all.py). + + This is what ``setup.py test`` uses. + + You can pass arguments and keyword arguments to the test functions that + support them (for now, test, doctest, and the examples). See the + docstrings of those functions for a description of the available options. + + For example, to run the solvers tests with colors turned off: + + >>> from sympy.testing.runtests import run_all_tests + >>> run_all_tests(test_args=("solvers",), + ... test_kwargs={"colors:False"}) # doctest: +SKIP + + """ + tests_successful = True + + test_kwargs = test_kwargs or {} + doctest_kwargs = doctest_kwargs or {} + examples_kwargs = examples_kwargs or {'quiet': True} + + try: + # Regular tests + if not test(*test_args, **test_kwargs): + # some regular test fails, so set the tests_successful + # flag to false and continue running the doctests + tests_successful = False + + # Doctests + print() + if not doctest(*doctest_args, **doctest_kwargs): + tests_successful = False + + # Examples + print() + sys.path.append("examples") # examples/all.py + from all import run_examples # type: ignore + if not run_examples(*examples_args, **examples_kwargs): + tests_successful = False + + if tests_successful: + return + else: + # Return nonzero exit code + sys.exit(1) + except KeyboardInterrupt: + print() + print("DO *NOT* COMMIT!") + sys.exit(1) + + +def test(*paths, subprocess=True, rerun=0, **kwargs): + """ + Run tests in the specified test_*.py files. + + Tests in a particular test_*.py file are run if any of the given strings + in ``paths`` matches a part of the test file's path. If ``paths=[]``, + tests in all test_*.py files are run. + + Notes: + + - If sort=False, tests are run in random order (not default). + - Paths can be entered in native system format or in unix, + forward-slash format. + - Files that are on the blacklist can be tested by providing + their path; they are only excluded if no paths are given. + + **Explanation of test results** + + ====== =============================================================== + Output Meaning + ====== =============================================================== + . passed + F failed + X XPassed (expected to fail but passed) + f XFAILed (expected to fail and indeed failed) + s skipped + w slow + T timeout (e.g., when ``--timeout`` is used) + K KeyboardInterrupt (when running the slow tests with ``--slow``, + you can interrupt one of them without killing the test runner) + ====== =============================================================== + + + Colors have no additional meaning and are used just to facilitate + interpreting the output. + + Examples + ======== + + >>> import sympy + + Run all tests: + + >>> sympy.test() # doctest: +SKIP + + Run one file: + + >>> sympy.test("sympy/core/tests/test_basic.py") # doctest: +SKIP + >>> sympy.test("_basic") # doctest: +SKIP + + Run all tests in sympy/functions/ and some particular file: + + >>> sympy.test("sympy/core/tests/test_basic.py", + ... "sympy/functions") # doctest: +SKIP + + Run all tests in sympy/core and sympy/utilities: + + >>> sympy.test("/core", "/util") # doctest: +SKIP + + Run specific test from a file: + + >>> sympy.test("sympy/core/tests/test_basic.py", + ... kw="test_equality") # doctest: +SKIP + + Run specific test from any file: + + >>> sympy.test(kw="subs") # doctest: +SKIP + + Run the tests with verbose mode on: + + >>> sympy.test(verbose=True) # doctest: +SKIP + + Do not sort the test output: + + >>> sympy.test(sort=False) # doctest: +SKIP + + Turn on post-mortem pdb: + + >>> sympy.test(pdb=True) # doctest: +SKIP + + Turn off colors: + + >>> sympy.test(colors=False) # doctest: +SKIP + + Force colors, even when the output is not to a terminal (this is useful, + e.g., if you are piping to ``less -r`` and you still want colors) + + >>> sympy.test(force_colors=False) # doctest: +SKIP + + The traceback verboseness can be set to "short" or "no" (default is + "short") + + >>> sympy.test(tb='no') # doctest: +SKIP + + The ``split`` option can be passed to split the test run into parts. The + split currently only splits the test files, though this may change in the + future. ``split`` should be a string of the form 'a/b', which will run + part ``a`` of ``b``. For instance, to run the first half of the test suite: + + >>> sympy.test(split='1/2') # doctest: +SKIP + + The ``time_balance`` option can be passed in conjunction with ``split``. + If ``time_balance=True`` (the default for ``sympy.test``), SymPy will attempt + to split the tests such that each split takes equal time. This heuristic + for balancing is based on pre-recorded test data. + + >>> sympy.test(split='1/2', time_balance=True) # doctest: +SKIP + + You can disable running the tests in a separate subprocess using + ``subprocess=False``. This is done to support seeding hash randomization, + which is enabled by default in the Python versions where it is supported. + If subprocess=False, hash randomization is enabled/disabled according to + whether it has been enabled or not in the calling Python process. + However, even if it is enabled, the seed cannot be printed unless it is + called from a new Python process. + + Hash randomization was added in the minor Python versions 2.6.8, 2.7.3, + 3.1.5, and 3.2.3, and is enabled by default in all Python versions after + and including 3.3.0. + + If hash randomization is not supported ``subprocess=False`` is used + automatically. + + >>> sympy.test(subprocess=False) # doctest: +SKIP + + To set the hash randomization seed, set the environment variable + ``PYTHONHASHSEED`` before running the tests. This can be done from within + Python using + + >>> import os + >>> os.environ['PYTHONHASHSEED'] = '42' # doctest: +SKIP + + Or from the command line using + + $ PYTHONHASHSEED=42 ./bin/test + + If the seed is not set, a random seed will be chosen. + + Note that to reproduce the same hash values, you must use both the same seed + as well as the same architecture (32-bit vs. 64-bit). + + """ + # count up from 0, do not print 0 + print_counter = lambda i : (print("rerun %d" % (rerun-i)) + if rerun-i else None) + + if subprocess: + # loop backwards so last i is 0 + for i in range(rerun, -1, -1): + print_counter(i) + ret = run_in_subprocess_with_hash_randomization("_test", + function_args=paths, function_kwargs=kwargs) + if ret is False: + break + val = not bool(ret) + # exit on the first failure or if done + if not val or i == 0: + return val + + # rerun even if hash randomization is not supported + for i in range(rerun, -1, -1): + print_counter(i) + val = not bool(_test(*paths, **kwargs)) + if not val or i == 0: + return val + + +def _test(*paths, + verbose=False, tb="short", kw=None, pdb=False, colors=True, + force_colors=False, sort=True, seed=None, timeout=False, + fail_on_timeout=False, slow=False, enhance_asserts=False, split=None, + time_balance=True, blacklist=(), + fast_threshold=None, slow_threshold=None): + """ + Internal function that actually runs the tests. + + All keyword arguments from ``test()`` are passed to this function except for + ``subprocess``. + + Returns 0 if tests passed and 1 if they failed. See the docstring of + ``test()`` for more information. + """ + kw = kw or () + # ensure that kw is a tuple + if isinstance(kw, str): + kw = (kw,) + post_mortem = pdb + if seed is None: + seed = random.randrange(100000000) + if ON_CI and timeout is False: + timeout = 595 + fail_on_timeout = True + if ON_CI: + blacklist = list(blacklist) + ['sympy/plotting/pygletplot/tests'] + blacklist = convert_to_native_paths(blacklist) + r = PyTestReporter(verbose=verbose, tb=tb, colors=colors, + force_colors=force_colors, split=split) + # This won't strictly run the test for the corresponding file, but it is + # good enough for copying and pasting the failing test. + _paths = [] + for path in paths: + if '::' in path: + path, _kw = path.split('::', 1) + kw += (_kw,) + _paths.append(path) + paths = _paths + + t = SymPyTests(r, kw, post_mortem, seed, + fast_threshold=fast_threshold, + slow_threshold=slow_threshold) + + test_files = t.get_test_files('sympy') + + not_blacklisted = [f for f in test_files + if not any(b in f for b in blacklist)] + + if len(paths) == 0: + matched = not_blacklisted + else: + paths = convert_to_native_paths(paths) + matched = [] + for f in not_blacklisted: + basename = os.path.basename(f) + for p in paths: + if p in f or fnmatch(basename, p): + matched.append(f) + break + + density = None + if time_balance: + if slow: + density = SPLIT_DENSITY_SLOW + else: + density = SPLIT_DENSITY + + if split: + matched = split_list(matched, split, density=density) + + t._testfiles.extend(matched) + + return int(not t.test(sort=sort, timeout=timeout, slow=slow, + enhance_asserts=enhance_asserts, fail_on_timeout=fail_on_timeout)) + + +def doctest(*paths, subprocess=True, rerun=0, **kwargs): + r""" + Runs doctests in all \*.py files in the SymPy directory which match + any of the given strings in ``paths`` or all tests if paths=[]. + + Notes: + + - Paths can be entered in native system format or in unix, + forward-slash format. + - Files that are on the blacklist can be tested by providing + their path; they are only excluded if no paths are given. + + Examples + ======== + + >>> import sympy + + Run all tests: + + >>> sympy.doctest() # doctest: +SKIP + + Run one file: + + >>> sympy.doctest("sympy/core/basic.py") # doctest: +SKIP + >>> sympy.doctest("polynomial.rst") # doctest: +SKIP + + Run all tests in sympy/functions/ and some particular file: + + >>> sympy.doctest("/functions", "basic.py") # doctest: +SKIP + + Run any file having polynomial in its name, doc/src/modules/polynomial.rst, + sympy/functions/special/polynomials.py, and sympy/polys/polynomial.py: + + >>> sympy.doctest("polynomial") # doctest: +SKIP + + The ``split`` option can be passed to split the test run into parts. The + split currently only splits the test files, though this may change in the + future. ``split`` should be a string of the form 'a/b', which will run + part ``a`` of ``b``. Note that the regular doctests and the Sphinx + doctests are split independently. For instance, to run the first half of + the test suite: + + >>> sympy.doctest(split='1/2') # doctest: +SKIP + + The ``subprocess`` and ``verbose`` options are the same as with the function + ``test()`` (see the docstring of that function for more information) except + that ``verbose`` may also be set equal to ``2`` in order to print + individual doctest lines, as they are being tested. + """ + # count up from 0, do not print 0 + print_counter = lambda i : (print("rerun %d" % (rerun-i)) + if rerun-i else None) + + if subprocess: + # loop backwards so last i is 0 + for i in range(rerun, -1, -1): + print_counter(i) + ret = run_in_subprocess_with_hash_randomization("_doctest", + function_args=paths, function_kwargs=kwargs) + if ret is False: + break + val = not bool(ret) + # exit on the first failure or if done + if not val or i == 0: + return val + + # rerun even if hash randomization is not supported + for i in range(rerun, -1, -1): + print_counter(i) + val = not bool(_doctest(*paths, **kwargs)) + if not val or i == 0: + return val + + +def _get_doctest_blacklist(): + '''Get the default blacklist for the doctests''' + blacklist = [] + + blacklist.extend([ + "doc/src/modules/plotting.rst", # generates live plots + "doc/src/modules/physics/mechanics/autolev_parser.rst", + "sympy/codegen/array_utils.py", # raises deprecation warning + "sympy/core/compatibility.py", # backwards compatibility shim, importing it triggers a deprecation warning + "sympy/core/trace.py", # backwards compatibility shim, importing it triggers a deprecation warning + "sympy/galgebra.py", # no longer part of SymPy + "sympy/parsing/autolev/_antlr/autolevlexer.py", # generated code + "sympy/parsing/autolev/_antlr/autolevlistener.py", # generated code + "sympy/parsing/autolev/_antlr/autolevparser.py", # generated code + "sympy/parsing/latex/_antlr/latexlexer.py", # generated code + "sympy/parsing/latex/_antlr/latexparser.py", # generated code + "sympy/plotting/pygletplot/__init__.py", # crashes on some systems + "sympy/plotting/pygletplot/plot.py", # crashes on some systems + "sympy/printing/ccode.py", # backwards compatibility shim, importing it breaks the codegen doctests + "sympy/printing/cxxcode.py", # backwards compatibility shim, importing it breaks the codegen doctests + "sympy/printing/fcode.py", # backwards compatibility shim, importing it breaks the codegen doctests + "sympy/testing/randtest.py", # backwards compatibility shim, importing it triggers a deprecation warning + "sympy/this.py", # prints text + ]) + # autolev parser tests + num = 12 + for i in range (1, num+1): + blacklist.append("sympy/parsing/autolev/test-examples/ruletest" + str(i) + ".py") + blacklist.extend(["sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.py", + "sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.py", + "sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.py", + "sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.py"]) + + if import_module('numpy') is None: + blacklist.extend([ + "sympy/plotting/experimental_lambdify.py", + "sympy/plotting/plot_implicit.py", + "examples/advanced/autowrap_integrators.py", + "examples/advanced/autowrap_ufuncify.py", + "examples/intermediate/sample.py", + "examples/intermediate/mplot2d.py", + "examples/intermediate/mplot3d.py", + "doc/src/modules/numeric-computation.rst", + "doc/src/explanation/best-practices.md", + "doc/src/tutorials/physics/biomechanics/biomechanical-model-example.rst", + "doc/src/tutorials/physics/biomechanics/biomechanics.rst", + ]) + else: + if import_module('matplotlib') is None: + blacklist.extend([ + "examples/intermediate/mplot2d.py", + "examples/intermediate/mplot3d.py" + ]) + else: + # Use a non-windowed backend, so that the tests work on CI + import matplotlib + matplotlib.use('Agg') + + if ON_CI or import_module('pyglet') is None: + blacklist.extend(["sympy/plotting/pygletplot"]) + + if import_module('aesara') is None: + blacklist.extend([ + "sympy/printing/aesaracode.py", + "doc/src/modules/numeric-computation.rst", + ]) + + if import_module('cupy') is None: + blacklist.extend([ + "doc/src/modules/numeric-computation.rst", + ]) + + if import_module('jax') is None: + blacklist.extend([ + "doc/src/modules/numeric-computation.rst", + ]) + + if import_module('antlr4') is None: + blacklist.extend([ + "sympy/parsing/autolev/__init__.py", + "sympy/parsing/latex/_parse_latex_antlr.py", + ]) + + if import_module('lfortran') is None: + #throws ImportError when lfortran not installed + blacklist.extend([ + "sympy/parsing/sym_expr.py", + ]) + + if import_module("scipy") is None: + # throws ModuleNotFoundError when scipy not installed + blacklist.extend([ + "doc/src/guides/solving/solve-numerically.md", + "doc/src/guides/solving/solve-ode.md", + ]) + + if import_module("numpy") is None: + # throws ModuleNotFoundError when numpy not installed + blacklist.extend([ + "doc/src/guides/solving/solve-ode.md", + "doc/src/guides/solving/solve-numerically.md", + ]) + + # disabled because of doctest failures in asmeurer's bot + blacklist.extend([ + "sympy/utilities/autowrap.py", + "examples/advanced/autowrap_integrators.py", + "examples/advanced/autowrap_ufuncify.py" + ]) + + blacklist.extend([ + "sympy/conftest.py", # Depends on pytest + ]) + + # These are deprecated stubs to be removed: + blacklist.extend([ + "sympy/utilities/tmpfiles.py", + "sympy/utilities/pytest.py", + "sympy/utilities/runtests.py", + "sympy/utilities/quality_unicode.py", + "sympy/utilities/randtest.py", + ]) + + blacklist = convert_to_native_paths(blacklist) + return blacklist + + +def _doctest(*paths, **kwargs): + """ + Internal function that actually runs the doctests. + + All keyword arguments from ``doctest()`` are passed to this function + except for ``subprocess``. + + Returns 0 if tests passed and 1 if they failed. See the docstrings of + ``doctest()`` and ``test()`` for more information. + """ + from sympy.printing.pretty.pretty import pprint_use_unicode + from sympy.printing.pretty import stringpict + + normal = kwargs.get("normal", False) + verbose = kwargs.get("verbose", False) + colors = kwargs.get("colors", True) + force_colors = kwargs.get("force_colors", False) + blacklist = kwargs.get("blacklist", []) + split = kwargs.get('split', None) + + blacklist.extend(_get_doctest_blacklist()) + + # Use a non-windowed backend, so that the tests work on CI + if import_module('matplotlib') is not None: + import matplotlib + matplotlib.use('Agg') + + # Disable warnings for external modules + import sympy.external + sympy.external.importtools.WARN_OLD_VERSION = False + sympy.external.importtools.WARN_NOT_INSTALLED = False + + # Disable showing up of plots + from sympy.plotting.plot import unset_show + unset_show() + + r = PyTestReporter(verbose, split=split, colors=colors,\ + force_colors=force_colors) + t = SymPyDocTests(r, normal) + + test_files = t.get_test_files('sympy') + test_files.extend(t.get_test_files('examples', init_only=False)) + + not_blacklisted = [f for f in test_files + if not any(b in f for b in blacklist)] + if len(paths) == 0: + matched = not_blacklisted + else: + # take only what was requested...but not blacklisted items + # and allow for partial match anywhere or fnmatch of name + paths = convert_to_native_paths(paths) + matched = [] + for f in not_blacklisted: + basename = os.path.basename(f) + for p in paths: + if p in f or fnmatch(basename, p): + matched.append(f) + break + + matched.sort() + + if split: + matched = split_list(matched, split) + + t._testfiles.extend(matched) + + # run the tests and record the result for this *py portion of the tests + if t._testfiles: + failed = not t.test() + else: + failed = False + + # N.B. + # -------------------------------------------------------------------- + # Here we test *.rst and *.md files at or below doc/src. Code from these + # must be self supporting in terms of imports since there is no importing + # of necessary modules by doctest.testfile. If you try to pass *.py files + # through this they might fail because they will lack the needed imports + # and smarter parsing that can be done with source code. + # + test_files_rst = t.get_test_files('doc/src', '*.rst', init_only=False) + test_files_md = t.get_test_files('doc/src', '*.md', init_only=False) + test_files = test_files_rst + test_files_md + test_files.sort() + + not_blacklisted = [f for f in test_files + if not any(b in f for b in blacklist)] + + if len(paths) == 0: + matched = not_blacklisted + else: + # Take only what was requested as long as it's not on the blacklist. + # Paths were already made native in *py tests so don't repeat here. + # There's no chance of having a *py file slip through since we + # only have *rst files in test_files. + matched = [] + for f in not_blacklisted: + basename = os.path.basename(f) + for p in paths: + if p in f or fnmatch(basename, p): + matched.append(f) + break + + if split: + matched = split_list(matched, split) + + first_report = True + for rst_file in matched: + if not os.path.isfile(rst_file): + continue + old_displayhook = sys.displayhook + try: + use_unicode_prev, wrap_line_prev = setup_pprint() + out = sympytestfile( + rst_file, module_relative=False, encoding='utf-8', + optionflags=pdoctest.ELLIPSIS | pdoctest.NORMALIZE_WHITESPACE | + pdoctest.IGNORE_EXCEPTION_DETAIL) + finally: + # make sure we return to the original displayhook in case some + # doctest has changed that + sys.displayhook = old_displayhook + # The NO_GLOBAL flag overrides the no_global flag to init_printing + # if True + import sympy.interactive.printing as interactive_printing + interactive_printing.NO_GLOBAL = False + pprint_use_unicode(use_unicode_prev) + stringpict._GLOBAL_WRAP_LINE = wrap_line_prev + + rstfailed, tested = out + if tested: + failed = rstfailed or failed + if first_report: + first_report = False + msg = 'rst/md doctests start' + if not t._testfiles: + r.start(msg=msg) + else: + r.write_center(msg) + print() + # use as the id, everything past the first 'sympy' + file_id = rst_file[rst_file.find('sympy') + len('sympy') + 1:] + print(file_id, end=" ") + # get at least the name out so it is know who is being tested + wid = r.terminal_width - len(file_id) - 1 # update width + test_file = '[%s]' % (tested) + report = '[%s]' % (rstfailed or 'OK') + print(''.join( + [test_file, ' '*(wid - len(test_file) - len(report)), report]) + ) + + # the doctests for *py will have printed this message already if there was + # a failure, so now only print it if there was intervening reporting by + # testing the *rst as evidenced by first_report no longer being True. + if not first_report and failed: + print() + print("DO *NOT* COMMIT!") + + return int(failed) + +sp = re.compile(r'([0-9]+)/([1-9][0-9]*)') + +def split_list(l, split, density=None): + """ + Splits a list into part a of b + + split should be a string of the form 'a/b'. For instance, '1/3' would give + the split one of three. + + If the length of the list is not divisible by the number of splits, the + last split will have more items. + + `density` may be specified as a list. If specified, + tests will be balanced so that each split has as equal-as-possible + amount of mass according to `density`. + + >>> from sympy.testing.runtests import split_list + >>> a = list(range(10)) + >>> split_list(a, '1/3') + [0, 1, 2] + >>> split_list(a, '2/3') + [3, 4, 5] + >>> split_list(a, '3/3') + [6, 7, 8, 9] + """ + m = sp.match(split) + if not m: + raise ValueError("split must be a string of the form a/b where a and b are ints") + i, t = map(int, m.groups()) + + if not density: + return l[(i - 1)*len(l)//t : i*len(l)//t] + + # normalize density + tot = sum(density) + density = [x / tot for x in density] + + def density_inv(x): + """Interpolate the inverse to the cumulative + distribution function given by density""" + if x <= 0: + return 0 + if x >= sum(density): + return 1 + + # find the first time the cumulative sum surpasses x + # and linearly interpolate + cumm = 0 + for i, d in enumerate(density): + cumm += d + if cumm >= x: + break + frac = (d - (cumm - x)) / d + return (i + frac) / len(density) + + lower_frac = density_inv((i - 1) / t) + higher_frac = density_inv(i / t) + return l[int(lower_frac*len(l)) : int(higher_frac*len(l))] + +from collections import namedtuple +SymPyTestResults = namedtuple('SymPyTestResults', 'failed attempted') + +def sympytestfile(filename, module_relative=True, name=None, package=None, + globs=None, verbose=None, report=True, optionflags=0, + extraglobs=None, raise_on_error=False, + parser=pdoctest.DocTestParser(), encoding=None): + + """ + Test examples in the given file. Return (#failures, #tests). + + Optional keyword arg ``module_relative`` specifies how filenames + should be interpreted: + + - If ``module_relative`` is True (the default), then ``filename`` + specifies a module-relative path. By default, this path is + relative to the calling module's directory; but if the + ``package`` argument is specified, then it is relative to that + package. To ensure os-independence, ``filename`` should use + "/" characters to separate path segments, and should not + be an absolute path (i.e., it may not begin with "/"). + + - If ``module_relative`` is False, then ``filename`` specifies an + os-specific path. The path may be absolute or relative (to + the current working directory). + + Optional keyword arg ``name`` gives the name of the test; by default + use the file's basename. + + Optional keyword argument ``package`` is a Python package or the + name of a Python package whose directory should be used as the + base directory for a module relative filename. If no package is + specified, then the calling module's directory is used as the base + directory for module relative filenames. It is an error to + specify ``package`` if ``module_relative`` is False. + + Optional keyword arg ``globs`` gives a dict to be used as the globals + when executing examples; by default, use {}. A copy of this dict + is actually used for each docstring, so that each docstring's + examples start with a clean slate. + + Optional keyword arg ``extraglobs`` gives a dictionary that should be + merged into the globals that are used to execute examples. By + default, no extra globals are used. + + Optional keyword arg ``verbose`` prints lots of stuff if true, prints + only failures if false; by default, it's true iff "-v" is in sys.argv. + + Optional keyword arg ``report`` prints a summary at the end when true, + else prints nothing at the end. In verbose mode, the summary is + detailed, else very brief (in fact, empty if all tests passed). + + Optional keyword arg ``optionflags`` or's together module constants, + and defaults to 0. Possible values (see the docs for details): + + - DONT_ACCEPT_TRUE_FOR_1 + - DONT_ACCEPT_BLANKLINE + - NORMALIZE_WHITESPACE + - ELLIPSIS + - SKIP + - IGNORE_EXCEPTION_DETAIL + - REPORT_UDIFF + - REPORT_CDIFF + - REPORT_NDIFF + - REPORT_ONLY_FIRST_FAILURE + + Optional keyword arg ``raise_on_error`` raises an exception on the + first unexpected exception or failure. This allows failures to be + post-mortem debugged. + + Optional keyword arg ``parser`` specifies a DocTestParser (or + subclass) that should be used to extract tests from the files. + + Optional keyword arg ``encoding`` specifies an encoding that should + be used to convert the file to unicode. + + Advanced tomfoolery: testmod runs methods of a local instance of + class doctest.Tester, then merges the results into (or creates) + global Tester instance doctest.master. Methods of doctest.master + can be called directly too, if you want to do something unusual. + Passing report=0 to testmod is especially useful then, to delay + displaying a summary. Invoke doctest.master.summarize(verbose) + when you're done fiddling. + """ + if package and not module_relative: + raise ValueError("Package may only be specified for module-" + "relative paths.") + + # Relativize the path + text, filename = pdoctest._load_testfile( + filename, package, module_relative, encoding) + + # If no name was given, then use the file's name. + if name is None: + name = os.path.basename(filename) + + # Assemble the globals. + if globs is None: + globs = {} + else: + globs = globs.copy() + if extraglobs is not None: + globs.update(extraglobs) + if '__name__' not in globs: + globs['__name__'] = '__main__' + + if raise_on_error: + runner = pdoctest.DebugRunner(verbose=verbose, optionflags=optionflags) + else: + runner = SymPyDocTestRunner(verbose=verbose, optionflags=optionflags) + runner._checker = SymPyOutputChecker() + + # Read the file, convert it to a test, and run it. + test = parser.get_doctest(text, globs, name, filename, 0) + runner.run(test) + + if report: + runner.summarize() + + if pdoctest.master is None: + pdoctest.master = runner + else: + pdoctest.master.merge(runner) + + return SymPyTestResults(runner.failures, runner.tries) + + +class SymPyTests: + + def __init__(self, reporter, kw="", post_mortem=False, + seed=None, fast_threshold=None, slow_threshold=None): + self._post_mortem = post_mortem + self._kw = kw + self._count = 0 + self._root_dir = get_sympy_dir() + self._reporter = reporter + self._reporter.root_dir(self._root_dir) + self._testfiles = [] + self._seed = seed if seed is not None else random.random() + + # Defaults in seconds, from human / UX design limits + # http://www.nngroup.com/articles/response-times-3-important-limits/ + # + # These defaults are *NOT* set in stone as we are measuring different + # things, so others feel free to come up with a better yardstick :) + if fast_threshold: + self._fast_threshold = float(fast_threshold) + else: + self._fast_threshold = 8 + if slow_threshold: + self._slow_threshold = float(slow_threshold) + else: + self._slow_threshold = 10 + + def test(self, sort=False, timeout=False, slow=False, + enhance_asserts=False, fail_on_timeout=False): + """ + Runs the tests returning True if all tests pass, otherwise False. + + If sort=False run tests in random order. + """ + if sort: + self._testfiles.sort() + elif slow: + pass + else: + random.seed(self._seed) + random.shuffle(self._testfiles) + self._reporter.start(self._seed) + for f in self._testfiles: + try: + self.test_file(f, sort, timeout, slow, + enhance_asserts, fail_on_timeout) + except KeyboardInterrupt: + print(" interrupted by user") + self._reporter.finish() + raise + return self._reporter.finish() + + def _enhance_asserts(self, source): + from ast import (NodeTransformer, Compare, Name, Store, Load, Tuple, + Assign, BinOp, Str, Mod, Assert, parse, fix_missing_locations) + + ops = {"Eq": '==', "NotEq": '!=', "Lt": '<', "LtE": '<=', + "Gt": '>', "GtE": '>=', "Is": 'is', "IsNot": 'is not', + "In": 'in', "NotIn": 'not in'} + + class Transform(NodeTransformer): + def visit_Assert(self, stmt): + if isinstance(stmt.test, Compare): + compare = stmt.test + values = [compare.left] + compare.comparators + names = [ "_%s" % i for i, _ in enumerate(values) ] + names_store = [ Name(n, Store()) for n in names ] + names_load = [ Name(n, Load()) for n in names ] + target = Tuple(names_store, Store()) + value = Tuple(values, Load()) + assign = Assign([target], value) + new_compare = Compare(names_load[0], compare.ops, names_load[1:]) + msg_format = "\n%s " + "\n%s ".join([ ops[op.__class__.__name__] for op in compare.ops ]) + "\n%s" + msg = BinOp(Str(msg_format), Mod(), Tuple(names_load, Load())) + test = Assert(new_compare, msg, lineno=stmt.lineno, col_offset=stmt.col_offset) + return [assign, test] + else: + return stmt + + tree = parse(source) + new_tree = Transform().visit(tree) + return fix_missing_locations(new_tree) + + def test_file(self, filename, sort=True, timeout=False, slow=False, + enhance_asserts=False, fail_on_timeout=False): + reporter = self._reporter + funcs = [] + try: + gl = {'__file__': filename} + try: + open_file = lambda: open(filename, encoding="utf8") + + with open_file() as f: + source = f.read() + if self._kw: + for l in source.splitlines(): + if l.lstrip().startswith('def '): + if any(l.lower().find(k.lower()) != -1 for k in self._kw): + break + else: + return + + if enhance_asserts: + try: + source = self._enhance_asserts(source) + except ImportError: + pass + + code = compile(source, filename, "exec", flags=0, dont_inherit=True) + exec(code, gl) + except (SystemExit, KeyboardInterrupt): + raise + except ImportError: + reporter.import_error(filename, sys.exc_info()) + return + except Exception: + reporter.test_exception(sys.exc_info()) + + clear_cache() + self._count += 1 + random.seed(self._seed) + disabled = gl.get("disabled", False) + if not disabled: + # we need to filter only those functions that begin with 'test_' + # We have to be careful about decorated functions. As long as + # the decorator uses functools.wraps, we can detect it. + funcs = [] + for f in gl: + if (f.startswith("test_") and (inspect.isfunction(gl[f]) + or inspect.ismethod(gl[f]))): + func = gl[f] + # Handle multiple decorators + while hasattr(func, '__wrapped__'): + func = func.__wrapped__ + + if inspect.getsourcefile(func) == filename: + funcs.append(gl[f]) + if slow: + funcs = [f for f in funcs if getattr(f, '_slow', False)] + # Sorting of XFAILed functions isn't fixed yet :-( + funcs.sort(key=lambda x: inspect.getsourcelines(x)[1]) + i = 0 + while i < len(funcs): + if inspect.isgeneratorfunction(funcs[i]): + # some tests can be generators, that return the actual + # test functions. We unpack it below: + f = funcs.pop(i) + for fg in f(): + func = fg[0] + args = fg[1:] + fgw = lambda: func(*args) + funcs.insert(i, fgw) + i += 1 + else: + i += 1 + # drop functions that are not selected with the keyword expression: + funcs = [x for x in funcs if self.matches(x)] + + if not funcs: + return + except Exception: + reporter.entering_filename(filename, len(funcs)) + raise + + reporter.entering_filename(filename, len(funcs)) + if not sort: + random.shuffle(funcs) + + for f in funcs: + start = time.time() + reporter.entering_test(f) + try: + if getattr(f, '_slow', False) and not slow: + raise Skipped("Slow") + with raise_on_deprecated(): + if timeout: + self._timeout(f, timeout, fail_on_timeout) + else: + random.seed(self._seed) + f() + except KeyboardInterrupt: + if getattr(f, '_slow', False): + reporter.test_skip("KeyboardInterrupt") + else: + raise + except Exception: + if timeout: + signal.alarm(0) # Disable the alarm. It could not be handled before. + t, v, tr = sys.exc_info() + if t is AssertionError: + reporter.test_fail((t, v, tr)) + if self._post_mortem: + pdb.post_mortem(tr) + elif t.__name__ == "Skipped": + reporter.test_skip(v) + elif t.__name__ == "XFail": + reporter.test_xfail() + elif t.__name__ == "XPass": + reporter.test_xpass(v) + else: + reporter.test_exception((t, v, tr)) + if self._post_mortem: + pdb.post_mortem(tr) + else: + reporter.test_pass() + taken = time.time() - start + if taken > self._slow_threshold: + filename = os.path.relpath(filename, reporter._root_dir) + reporter.slow_test_functions.append( + (filename + "::" + f.__name__, taken)) + if getattr(f, '_slow', False) and slow: + if taken < self._fast_threshold: + filename = os.path.relpath(filename, reporter._root_dir) + reporter.fast_test_functions.append( + (filename + "::" + f.__name__, taken)) + reporter.leaving_filename() + + def _timeout(self, function, timeout, fail_on_timeout): + def callback(x, y): + signal.alarm(0) + if fail_on_timeout: + raise TimeOutError("Timed out after %d seconds" % timeout) + else: + raise Skipped("Timeout") + signal.signal(signal.SIGALRM, callback) + signal.alarm(timeout) # Set an alarm with a given timeout + function() + signal.alarm(0) # Disable the alarm + + def matches(self, x): + """ + Does the keyword expression self._kw match "x"? Returns True/False. + + Always returns True if self._kw is "". + """ + if not self._kw: + return True + for kw in self._kw: + if x.__name__.lower().find(kw.lower()) != -1: + return True + return False + + def get_test_files(self, dir, pat='test_*.py'): + """ + Returns the list of test_*.py (default) files at or below directory + ``dir`` relative to the SymPy home directory. + """ + dir = os.path.join(self._root_dir, convert_to_native_paths([dir])[0]) + + g = [] + for path, folders, files in os.walk(dir): + g.extend([os.path.join(path, f) for f in files if fnmatch(f, pat)]) + + return sorted([os.path.normcase(gi) for gi in g]) + + +class SymPyDocTests: + + def __init__(self, reporter, normal): + self._count = 0 + self._root_dir = get_sympy_dir() + self._reporter = reporter + self._reporter.root_dir(self._root_dir) + self._normal = normal + + self._testfiles = [] + + def test(self): + """ + Runs the tests and returns True if all tests pass, otherwise False. + """ + self._reporter.start() + for f in self._testfiles: + try: + self.test_file(f) + except KeyboardInterrupt: + print(" interrupted by user") + self._reporter.finish() + raise + return self._reporter.finish() + + def test_file(self, filename): + clear_cache() + + from io import StringIO + import sympy.interactive.printing as interactive_printing + from sympy.printing.pretty.pretty import pprint_use_unicode + from sympy.printing.pretty import stringpict + + rel_name = filename[len(self._root_dir) + 1:] + dirname, file = os.path.split(filename) + module = rel_name.replace(os.sep, '.')[:-3] + + if rel_name.startswith("examples"): + # Examples files do not have __init__.py files, + # So we have to temporarily extend sys.path to import them + sys.path.insert(0, dirname) + module = file[:-3] # remove ".py" + try: + module = pdoctest._normalize_module(module) + tests = SymPyDocTestFinder().find(module) + except (SystemExit, KeyboardInterrupt): + raise + except ImportError: + self._reporter.import_error(filename, sys.exc_info()) + return + finally: + if rel_name.startswith("examples"): + del sys.path[0] + + tests = [test for test in tests if len(test.examples) > 0] + # By default tests are sorted by alphabetical order by function name. + # We sort by line number so one can edit the file sequentially from + # bottom to top. However, if there are decorated functions, their line + # numbers will be too large and for now one must just search for these + # by text and function name. + tests.sort(key=lambda x: -x.lineno) + + if not tests: + return + self._reporter.entering_filename(filename, len(tests)) + for test in tests: + assert len(test.examples) != 0 + + if self._reporter._verbose: + self._reporter.write("\n{} ".format(test.name)) + + # check if there are external dependencies which need to be met + if '_doctest_depends_on' in test.globs: + try: + self._check_dependencies(**test.globs['_doctest_depends_on']) + except DependencyError as e: + self._reporter.test_skip(v=str(e)) + continue + + runner = SymPyDocTestRunner(verbose=self._reporter._verbose==2, + optionflags=pdoctest.ELLIPSIS | + pdoctest.NORMALIZE_WHITESPACE | + pdoctest.IGNORE_EXCEPTION_DETAIL) + runner._checker = SymPyOutputChecker() + old = sys.stdout + new = old if self._reporter._verbose==2 else StringIO() + sys.stdout = new + # If the testing is normal, the doctests get importing magic to + # provide the global namespace. If not normal (the default) then + # then must run on their own; all imports must be explicit within + # a function's docstring. Once imported that import will be + # available to the rest of the tests in a given function's + # docstring (unless clear_globs=True below). + if not self._normal: + test.globs = {} + # if this is uncommented then all the test would get is what + # comes by default with a "from sympy import *" + #exec('from sympy import *') in test.globs + old_displayhook = sys.displayhook + use_unicode_prev, wrap_line_prev = setup_pprint() + + try: + f, t = runner.run(test, + out=new.write, clear_globs=False) + except KeyboardInterrupt: + raise + finally: + sys.stdout = old + if f > 0: + self._reporter.doctest_fail(test.name, new.getvalue()) + else: + self._reporter.test_pass() + sys.displayhook = old_displayhook + interactive_printing.NO_GLOBAL = False + pprint_use_unicode(use_unicode_prev) + stringpict._GLOBAL_WRAP_LINE = wrap_line_prev + + self._reporter.leaving_filename() + + def get_test_files(self, dir, pat='*.py', init_only=True): + r""" + Returns the list of \*.py files (default) from which docstrings + will be tested which are at or below directory ``dir``. By default, + only those that have an __init__.py in their parent directory + and do not start with ``test_`` will be included. + """ + def importable(x): + """ + Checks if given pathname x is an importable module by checking for + __init__.py file. + + Returns True/False. + + Currently we only test if the __init__.py file exists in the + directory with the file "x" (in theory we should also test all the + parent dirs). + """ + init_py = os.path.join(os.path.dirname(x), "__init__.py") + return os.path.exists(init_py) + + dir = os.path.join(self._root_dir, convert_to_native_paths([dir])[0]) + + g = [] + for path, folders, files in os.walk(dir): + g.extend([os.path.join(path, f) for f in files + if not f.startswith('test_') and fnmatch(f, pat)]) + if init_only: + # skip files that are not importable (i.e. missing __init__.py) + g = [x for x in g if importable(x)] + + return [os.path.normcase(gi) for gi in g] + + def _check_dependencies(self, + executables=(), + modules=(), + disable_viewers=(), + python_version=(3, 5), + ground_types=None): + """ + Checks if the dependencies for the test are installed. + + Raises ``DependencyError`` it at least one dependency is not installed. + """ + + for executable in executables: + if not shutil.which(executable): + raise DependencyError("Could not find %s" % executable) + + for module in modules: + if module == 'matplotlib': + matplotlib = import_module( + 'matplotlib', + import_kwargs={'fromlist': + ['pyplot', 'cm', 'collections']}, + min_module_version='1.0.0', catch=(RuntimeError,)) + if matplotlib is None: + raise DependencyError("Could not import matplotlib") + else: + if not import_module(module): + raise DependencyError("Could not import %s" % module) + + if disable_viewers: + tempdir = tempfile.mkdtemp() + os.environ['PATH'] = '%s:%s' % (tempdir, os.environ['PATH']) + + vw = ('#!/usr/bin/env python3\n' + 'import sys\n' + 'if len(sys.argv) <= 1:\n' + ' exit("wrong number of args")\n') + + for viewer in disable_viewers: + Path(os.path.join(tempdir, viewer)).write_text(vw) + + # make the file executable + os.chmod(os.path.join(tempdir, viewer), + stat.S_IREAD | stat.S_IWRITE | stat.S_IXUSR) + + if python_version: + if sys.version_info < python_version: + raise DependencyError("Requires Python >= " + '.'.join(map(str, python_version))) + + if ground_types is not None: + if GROUND_TYPES not in ground_types: + raise DependencyError("Requires ground_types in " + str(ground_types)) + + if 'pyglet' in modules: + # monkey-patch pyglet s.t. it does not open a window during + # doctesting + import pyglet + class DummyWindow: + def __init__(self, *args, **kwargs): + self.has_exit = True + self.width = 600 + self.height = 400 + + def set_vsync(self, x): + pass + + def switch_to(self): + pass + + def push_handlers(self, x): + pass + + def close(self): + pass + + pyglet.window.Window = DummyWindow + + +class SymPyDocTestFinder(DocTestFinder): + """ + A class used to extract the DocTests that are relevant to a given + object, from its docstring and the docstrings of its contained + objects. Doctests can currently be extracted from the following + object types: modules, functions, classes, methods, staticmethods, + classmethods, and properties. + + Modified from doctest's version to look harder for code that + appears comes from a different module. For example, the @vectorize + decorator makes it look like functions come from multidimensional.py + even though their code exists elsewhere. + """ + + def _find(self, tests, obj, name, module, source_lines, globs, seen): + """ + Find tests for the given object and any contained objects, and + add them to ``tests``. + """ + if self._verbose: + print('Finding tests in %s' % name) + + # If we've already processed this object, then ignore it. + if id(obj) in seen: + return + seen[id(obj)] = 1 + + # Make sure we don't run doctests for classes outside of sympy, such + # as in numpy or scipy. + if inspect.isclass(obj): + if obj.__module__.split('.')[0] != 'sympy': + return + + # Find a test for this object, and add it to the list of tests. + test = self._get_test(obj, name, module, globs, source_lines) + if test is not None: + tests.append(test) + + if not self._recurse: + return + + # Look for tests in a module's contained objects. + if inspect.ismodule(obj): + for rawname, val in obj.__dict__.items(): + # Recurse to functions & classes. + if inspect.isfunction(val) or inspect.isclass(val): + # Make sure we don't run doctests functions or classes + # from different modules + if val.__module__ != module.__name__: + continue + + assert self._from_module(module, val), \ + "%s is not in module %s (rawname %s)" % (val, module, rawname) + + try: + valname = '%s.%s' % (name, rawname) + self._find(tests, val, valname, module, + source_lines, globs, seen) + except KeyboardInterrupt: + raise + + # Look for tests in a module's __test__ dictionary. + for valname, val in getattr(obj, '__test__', {}).items(): + if not isinstance(valname, str): + raise ValueError("SymPyDocTestFinder.find: __test__ keys " + "must be strings: %r" % + (type(valname),)) + if not (inspect.isfunction(val) or inspect.isclass(val) or + inspect.ismethod(val) or inspect.ismodule(val) or + isinstance(val, str)): + raise ValueError("SymPyDocTestFinder.find: __test__ values " + "must be strings, functions, methods, " + "classes, or modules: %r" % + (type(val),)) + valname = '%s.__test__.%s' % (name, valname) + self._find(tests, val, valname, module, source_lines, + globs, seen) + + + # Look for tests in a class's contained objects. + if inspect.isclass(obj): + for valname, val in obj.__dict__.items(): + # Special handling for staticmethod/classmethod. + if isinstance(val, staticmethod): + val = getattr(obj, valname) + if isinstance(val, classmethod): + val = getattr(obj, valname).__func__ + + + # Recurse to methods, properties, and nested classes. + if ((inspect.isfunction(unwrap(val)) or + inspect.isclass(val) or + isinstance(val, property)) and + self._from_module(module, val)): + # Make sure we don't run doctests functions or classes + # from different modules + if isinstance(val, property): + if hasattr(val.fget, '__module__'): + if val.fget.__module__ != module.__name__: + continue + else: + if val.__module__ != module.__name__: + continue + + assert self._from_module(module, val), \ + "%s is not in module %s (valname %s)" % ( + val, module, valname) + + valname = '%s.%s' % (name, valname) + self._find(tests, val, valname, module, source_lines, + globs, seen) + + def _get_test(self, obj, name, module, globs, source_lines): + """ + Return a DocTest for the given object, if it defines a docstring; + otherwise, return None. + """ + + lineno = None + + # Extract the object's docstring. If it does not have one, + # then return None (no test for this object). + if isinstance(obj, str): + # obj is a string in the case for objects in the polys package. + # Note that source_lines is a binary string (compiled polys + # modules), which can't be handled by _find_lineno so determine + # the line number here. + + docstring = obj + + matches = re.findall(r"line \d+", name) + assert len(matches) == 1, \ + "string '%s' does not contain lineno " % name + + # NOTE: this is not the exact linenumber but its better than no + # lineno ;) + lineno = int(matches[0][5:]) + + else: + docstring = getattr(obj, '__doc__', '') + if docstring is None: + docstring = '' + if not isinstance(docstring, str): + docstring = str(docstring) + + # Don't bother if the docstring is empty. + if self._exclude_empty and not docstring: + return None + + # check that properties have a docstring because _find_lineno + # assumes it + if isinstance(obj, property): + if obj.fget.__doc__ is None: + return None + + # Find the docstring's location in the file. + if lineno is None: + obj = unwrap(obj) + # handling of properties is not implemented in _find_lineno so do + # it here + if hasattr(obj, 'func_closure') and obj.func_closure is not None: + tobj = obj.func_closure[0].cell_contents + elif isinstance(obj, property): + tobj = obj.fget + else: + tobj = obj + lineno = self._find_lineno(tobj, source_lines) + + if lineno is None: + return None + + # Return a DocTest for this object. + if module is None: + filename = None + else: + filename = getattr(module, '__file__', module.__name__) + if filename[-4:] in (".pyc", ".pyo"): + filename = filename[:-1] + + globs['_doctest_depends_on'] = getattr(obj, '_doctest_depends_on', {}) + + return self._parser.get_doctest(docstring, globs, name, + filename, lineno) + + +class SymPyDocTestRunner(DocTestRunner): + """ + A class used to run DocTest test cases, and accumulate statistics. + The ``run`` method is used to process a single DocTest case. It + returns a tuple ``(f, t)``, where ``t`` is the number of test cases + tried, and ``f`` is the number of test cases that failed. + + Modified from the doctest version to not reset the sys.displayhook (see + issue 5140). + + See the docstring of the original DocTestRunner for more information. + """ + + def run(self, test, compileflags=None, out=None, clear_globs=True): + """ + Run the examples in ``test``, and display the results using the + writer function ``out``. + + The examples are run in the namespace ``test.globs``. If + ``clear_globs`` is true (the default), then this namespace will + be cleared after the test runs, to help with garbage + collection. If you would like to examine the namespace after + the test completes, then use ``clear_globs=False``. + + ``compileflags`` gives the set of flags that should be used by + the Python compiler when running the examples. If not + specified, then it will default to the set of future-import + flags that apply to ``globs``. + + The output of each example is checked using + ``SymPyDocTestRunner.check_output``, and the results are + formatted by the ``SymPyDocTestRunner.report_*`` methods. + """ + self.test = test + + # Remove ``` from the end of example, which may appear in Markdown + # files + for example in test.examples: + example.want = example.want.replace('```\n', '') + example.exc_msg = example.exc_msg and example.exc_msg.replace('```\n', '') + + + if compileflags is None: + compileflags = pdoctest._extract_future_flags(test.globs) + + save_stdout = sys.stdout + if out is None: + out = save_stdout.write + sys.stdout = self._fakeout + + # Patch pdb.set_trace to restore sys.stdout during interactive + # debugging (so it's not still redirected to self._fakeout). + # Note that the interactive output will go to *our* + # save_stdout, even if that's not the real sys.stdout; this + # allows us to write test cases for the set_trace behavior. + save_set_trace = pdb.set_trace + self.debugger = pdoctest._OutputRedirectingPdb(save_stdout) + self.debugger.reset() + pdb.set_trace = self.debugger.set_trace + + # Patch linecache.getlines, so we can see the example's source + # when we're inside the debugger. + self.save_linecache_getlines = pdoctest.linecache.getlines + linecache.getlines = self.__patched_linecache_getlines + + # Fail for deprecation warnings + with raise_on_deprecated(): + try: + return self.__run(test, compileflags, out) + finally: + sys.stdout = save_stdout + pdb.set_trace = save_set_trace + linecache.getlines = self.save_linecache_getlines + if clear_globs: + test.globs.clear() + + +# We have to override the name mangled methods. +monkeypatched_methods = [ + 'patched_linecache_getlines', + 'run', + 'record_outcome' +] +for method in monkeypatched_methods: + oldname = '_DocTestRunner__' + method + newname = '_SymPyDocTestRunner__' + method + setattr(SymPyDocTestRunner, newname, getattr(DocTestRunner, oldname)) + + +class SymPyOutputChecker(pdoctest.OutputChecker): + """ + Compared to the OutputChecker from the stdlib our OutputChecker class + supports numerical comparison of floats occurring in the output of the + doctest examples + """ + + def __init__(self): + # NOTE OutputChecker is an old-style class with no __init__ method, + # so we can't call the base class version of __init__ here + + got_floats = r'(\d+\.\d*|\.\d+)' + + # floats in the 'want' string may contain ellipses + want_floats = got_floats + r'(\.{3})?' + + front_sep = r'\s|\+|\-|\*|,' + back_sep = front_sep + r'|j|e' + + fbeg = r'^%s(?=%s|$)' % (got_floats, back_sep) + fmidend = r'(?<=%s)%s(?=%s|$)' % (front_sep, got_floats, back_sep) + self.num_got_rgx = re.compile(r'(%s|%s)' %(fbeg, fmidend)) + + fbeg = r'^%s(?=%s|$)' % (want_floats, back_sep) + fmidend = r'(?<=%s)%s(?=%s|$)' % (front_sep, want_floats, back_sep) + self.num_want_rgx = re.compile(r'(%s|%s)' %(fbeg, fmidend)) + + def check_output(self, want, got, optionflags): + """ + Return True iff the actual output from an example (`got`) + matches the expected output (`want`). These strings are + always considered to match if they are identical; but + depending on what option flags the test runner is using, + several non-exact match types are also possible. See the + documentation for `TestRunner` for more information about + option flags. + """ + # Handle the common case first, for efficiency: + # if they're string-identical, always return true. + if got == want: + return True + + # TODO parse integers as well ? + # Parse floats and compare them. If some of the parsed floats contain + # ellipses, skip the comparison. + matches = self.num_got_rgx.finditer(got) + numbers_got = [match.group(1) for match in matches] # list of strs + matches = self.num_want_rgx.finditer(want) + numbers_want = [match.group(1) for match in matches] # list of strs + if len(numbers_got) != len(numbers_want): + return False + + if len(numbers_got) > 0: + nw_ = [] + for ng, nw in zip(numbers_got, numbers_want): + if '...' in nw: + nw_.append(ng) + continue + else: + nw_.append(nw) + + if abs(float(ng)-float(nw)) > 1e-5: + return False + + got = self.num_got_rgx.sub(r'%s', got) + got = got % tuple(nw_) + + # can be used as a special sequence to signify a + # blank line, unless the DONT_ACCEPT_BLANKLINE flag is used. + if not (optionflags & pdoctest.DONT_ACCEPT_BLANKLINE): + # Replace in want with a blank line. + want = re.sub(r'(?m)^%s\s*?$' % re.escape(pdoctest.BLANKLINE_MARKER), + '', want) + # If a line in got contains only spaces, then remove the + # spaces. + got = re.sub(r'(?m)^\s*?$', '', got) + if got == want: + return True + + # This flag causes doctest to ignore any differences in the + # contents of whitespace strings. Note that this can be used + # in conjunction with the ELLIPSIS flag. + if optionflags & pdoctest.NORMALIZE_WHITESPACE: + got = ' '.join(got.split()) + want = ' '.join(want.split()) + if got == want: + return True + + # The ELLIPSIS flag says to let the sequence "..." in `want` + # match any substring in `got`. + if optionflags & pdoctest.ELLIPSIS: + if pdoctest._ellipsis_match(want, got): + return True + + # We didn't find any match; return false. + return False + + +class Reporter: + """ + Parent class for all reporters. + """ + pass + + +class PyTestReporter(Reporter): + """ + Py.test like reporter. Should produce output identical to py.test. + """ + + def __init__(self, verbose=False, tb="short", colors=True, + force_colors=False, split=None): + self._verbose = verbose + self._tb_style = tb + self._colors = colors + self._force_colors = force_colors + self._xfailed = 0 + self._xpassed = [] + self._failed = [] + self._failed_doctest = [] + self._passed = 0 + self._skipped = 0 + self._exceptions = [] + self._terminal_width = None + self._default_width = 80 + self._split = split + self._active_file = '' + self._active_f = None + + # TODO: Should these be protected? + self.slow_test_functions = [] + self.fast_test_functions = [] + + # this tracks the x-position of the cursor (useful for positioning + # things on the screen), without the need for any readline library: + self._write_pos = 0 + self._line_wrap = False + + def root_dir(self, dir): + self._root_dir = dir + + @property + def terminal_width(self): + if self._terminal_width is not None: + return self._terminal_width + + def findout_terminal_width(): + if sys.platform == "win32": + # Windows support is based on: + # + # http://code.activestate.com/recipes/ + # 440694-determine-size-of-console-window-on-windows/ + + from ctypes import windll, create_string_buffer + + h = windll.kernel32.GetStdHandle(-12) + csbi = create_string_buffer(22) + res = windll.kernel32.GetConsoleScreenBufferInfo(h, csbi) + + if res: + import struct + (_, _, _, _, _, left, _, right, _, _, _) = \ + struct.unpack("hhhhHhhhhhh", csbi.raw) + return right - left + else: + return self._default_width + + if hasattr(sys.stdout, 'isatty') and not sys.stdout.isatty(): + return self._default_width # leave PIPEs alone + + try: + process = subprocess.Popen(['stty', '-a'], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + stdout = stdout.decode("utf-8") + except OSError: + pass + else: + # We support the following output formats from stty: + # + # 1) Linux -> columns 80 + # 2) OS X -> 80 columns + # 3) Solaris -> columns = 80 + + re_linux = r"columns\s+(?P\d+);" + re_osx = r"(?P\d+)\s*columns;" + re_solaris = r"columns\s+=\s+(?P\d+);" + + for regex in (re_linux, re_osx, re_solaris): + match = re.search(regex, stdout) + + if match is not None: + columns = match.group('columns') + + try: + width = int(columns) + except ValueError: + pass + if width != 0: + return width + + return self._default_width + + width = findout_terminal_width() + self._terminal_width = width + + return width + + def write(self, text, color="", align="left", width=None, + force_colors=False): + """ + Prints a text on the screen. + + It uses sys.stdout.write(), so no readline library is necessary. + + Parameters + ========== + + color : choose from the colors below, "" means default color + align : "left"/"right", "left" is a normal print, "right" is aligned on + the right-hand side of the screen, filled with spaces if + necessary + width : the screen width + + """ + color_templates = ( + ("Black", "0;30"), + ("Red", "0;31"), + ("Green", "0;32"), + ("Brown", "0;33"), + ("Blue", "0;34"), + ("Purple", "0;35"), + ("Cyan", "0;36"), + ("LightGray", "0;37"), + ("DarkGray", "1;30"), + ("LightRed", "1;31"), + ("LightGreen", "1;32"), + ("Yellow", "1;33"), + ("LightBlue", "1;34"), + ("LightPurple", "1;35"), + ("LightCyan", "1;36"), + ("White", "1;37"), + ) + + colors = {} + + for name, value in color_templates: + colors[name] = value + c_normal = '\033[0m' + c_color = '\033[%sm' + + if width is None: + width = self.terminal_width + + if align == "right": + if self._write_pos + len(text) > width: + # we don't fit on the current line, create a new line + self.write("\n") + self.write(" "*(width - self._write_pos - len(text))) + + if not self._force_colors and hasattr(sys.stdout, 'isatty') and not \ + sys.stdout.isatty(): + # the stdout is not a terminal, this for example happens if the + # output is piped to less, e.g. "bin/test | less". In this case, + # the terminal control sequences would be printed verbatim, so + # don't use any colors. + color = "" + elif sys.platform == "win32": + # Windows consoles don't support ANSI escape sequences + color = "" + elif not self._colors: + color = "" + + if self._line_wrap: + if text[0] != "\n": + sys.stdout.write("\n") + + # Avoid UnicodeEncodeError when printing out test failures + if IS_WINDOWS: + text = text.encode('raw_unicode_escape').decode('utf8', 'ignore') + elif not sys.stdout.encoding.lower().startswith('utf'): + text = text.encode(sys.stdout.encoding, 'backslashreplace' + ).decode(sys.stdout.encoding) + + if color == "": + sys.stdout.write(text) + else: + sys.stdout.write("%s%s%s" % + (c_color % colors[color], text, c_normal)) + sys.stdout.flush() + l = text.rfind("\n") + if l == -1: + self._write_pos += len(text) + else: + self._write_pos = len(text) - l - 1 + self._line_wrap = self._write_pos >= width + self._write_pos %= width + + def write_center(self, text, delim="="): + width = self.terminal_width + if text != "": + text = " %s " % text + idx = (width - len(text)) // 2 + t = delim*idx + text + delim*(width - idx - len(text)) + self.write(t + "\n") + + def write_exception(self, e, val, tb): + # remove the first item, as that is always runtests.py + tb = tb.tb_next + t = traceback.format_exception(e, val, tb) + self.write("".join(t)) + + def start(self, seed=None, msg="test process starts"): + self.write_center(msg) + executable = sys.executable + v = tuple(sys.version_info) + python_version = "%s.%s.%s-%s-%s" % v + implementation = platform.python_implementation() + if implementation == 'PyPy': + implementation += " %s.%s.%s-%s-%s" % sys.pypy_version_info + self.write("executable: %s (%s) [%s]\n" % + (executable, python_version, implementation)) + from sympy.utilities.misc import ARCH + self.write("architecture: %s\n" % ARCH) + from sympy.core.cache import USE_CACHE + self.write("cache: %s\n" % USE_CACHE) + version = '' + if GROUND_TYPES =='gmpy': + import gmpy2 as gmpy + version = gmpy.version() + self.write("ground types: %s %s\n" % (GROUND_TYPES, version)) + numpy = import_module('numpy') + self.write("numpy: %s\n" % (None if not numpy else numpy.__version__)) + if seed is not None: + self.write("random seed: %d\n" % seed) + from sympy.utilities.misc import HASH_RANDOMIZATION + self.write("hash randomization: ") + hash_seed = os.getenv("PYTHONHASHSEED") or '0' + if HASH_RANDOMIZATION and (hash_seed == "random" or int(hash_seed)): + self.write("on (PYTHONHASHSEED=%s)\n" % hash_seed) + else: + self.write("off\n") + if self._split: + self.write("split: %s\n" % self._split) + self.write('\n') + self._t_start = clock() + + def finish(self): + self._t_end = clock() + self.write("\n") + global text, linelen + text = "tests finished: %d passed, " % self._passed + linelen = len(text) + + def add_text(mytext): + global text, linelen + """Break new text if too long.""" + if linelen + len(mytext) > self.terminal_width: + text += '\n' + linelen = 0 + text += mytext + linelen += len(mytext) + + if len(self._failed) > 0: + add_text("%d failed, " % len(self._failed)) + if len(self._failed_doctest) > 0: + add_text("%d failed, " % len(self._failed_doctest)) + if self._skipped > 0: + add_text("%d skipped, " % self._skipped) + if self._xfailed > 0: + add_text("%d expected to fail, " % self._xfailed) + if len(self._xpassed) > 0: + add_text("%d expected to fail but passed, " % len(self._xpassed)) + if len(self._exceptions) > 0: + add_text("%d exceptions, " % len(self._exceptions)) + add_text("in %.2f seconds" % (self._t_end - self._t_start)) + + if self.slow_test_functions: + self.write_center('slowest tests', '_') + sorted_slow = sorted(self.slow_test_functions, key=lambda r: r[1]) + for slow_func_name, taken in sorted_slow: + print('%s - Took %.3f seconds' % (slow_func_name, taken)) + + if self.fast_test_functions: + self.write_center('unexpectedly fast tests', '_') + sorted_fast = sorted(self.fast_test_functions, + key=lambda r: r[1]) + for fast_func_name, taken in sorted_fast: + print('%s - Took %.3f seconds' % (fast_func_name, taken)) + + if len(self._xpassed) > 0: + self.write_center("xpassed tests", "_") + for e in self._xpassed: + self.write("%s: %s\n" % (e[0], e[1])) + self.write("\n") + + if self._tb_style != "no" and len(self._exceptions) > 0: + for e in self._exceptions: + filename, f, (t, val, tb) = e + self.write_center("", "_") + if f is None: + s = "%s" % filename + else: + s = "%s:%s" % (filename, f.__name__) + self.write_center(s, "_") + self.write_exception(t, val, tb) + self.write("\n") + + if self._tb_style != "no" and len(self._failed) > 0: + for e in self._failed: + filename, f, (t, val, tb) = e + self.write_center("", "_") + self.write_center("%s::%s" % (filename, f.__name__), "_") + self.write_exception(t, val, tb) + self.write("\n") + + if self._tb_style != "no" and len(self._failed_doctest) > 0: + for e in self._failed_doctest: + filename, msg = e + self.write_center("", "_") + self.write_center("%s" % filename, "_") + self.write(msg) + self.write("\n") + + self.write_center(text) + ok = len(self._failed) == 0 and len(self._exceptions) == 0 and \ + len(self._failed_doctest) == 0 + if not ok: + self.write("DO *NOT* COMMIT!\n") + return ok + + def entering_filename(self, filename, n): + rel_name = filename[len(self._root_dir) + 1:] + self._active_file = rel_name + self._active_file_error = False + self.write(rel_name) + self.write("[%d] " % n) + + def leaving_filename(self): + self.write(" ") + if self._active_file_error: + self.write("[FAIL]", "Red", align="right") + else: + self.write("[OK]", "Green", align="right") + self.write("\n") + if self._verbose: + self.write("\n") + + def entering_test(self, f): + self._active_f = f + if self._verbose: + self.write("\n" + f.__name__ + " ") + + def test_xfail(self): + self._xfailed += 1 + self.write("f", "Green") + + def test_xpass(self, v): + message = str(v) + self._xpassed.append((self._active_file, message)) + self.write("X", "Green") + + def test_fail(self, exc_info): + self._failed.append((self._active_file, self._active_f, exc_info)) + self.write("F", "Red") + self._active_file_error = True + + def doctest_fail(self, name, error_msg): + # the first line contains "******", remove it: + error_msg = "\n".join(error_msg.split("\n")[1:]) + self._failed_doctest.append((name, error_msg)) + self.write("F", "Red") + self._active_file_error = True + + def test_pass(self, char="."): + self._passed += 1 + if self._verbose: + self.write("ok", "Green") + else: + self.write(char, "Green") + + def test_skip(self, v=None): + char = "s" + self._skipped += 1 + if v is not None: + message = str(v) + if message == "KeyboardInterrupt": + char = "K" + elif message == "Timeout": + char = "T" + elif message == "Slow": + char = "w" + if self._verbose: + if v is not None: + self.write(message + ' ', "Blue") + else: + self.write(" - ", "Blue") + self.write(char, "Blue") + + def test_exception(self, exc_info): + self._exceptions.append((self._active_file, self._active_f, exc_info)) + if exc_info[0] is TimeOutError: + self.write("T", "Red") + else: + self.write("E", "Red") + self._active_file_error = True + + def import_error(self, filename, exc_info): + self._exceptions.append((filename, None, exc_info)) + rel_name = filename[len(self._root_dir) + 1:] + self.write(rel_name) + self.write("[?] Failed to import", "Red") + self.write(" ") + self.write("[FAIL]", "Red", align="right") + self.write("\n") diff --git a/.venv/lib/python3.13/site-packages/sympy/testing/runtests_pytest.py b/.venv/lib/python3.13/site-packages/sympy/testing/runtests_pytest.py new file mode 100644 index 0000000000000000000000000000000000000000..635f27864ca86571128e6c9a055199dfbde1ed63 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/testing/runtests_pytest.py @@ -0,0 +1,461 @@ +"""Backwards compatible functions for running tests from SymPy using pytest. + +SymPy historically had its own testing framework that aimed to: +- be compatible with pytest; +- operate similarly (or identically) to pytest; +- not require any external dependencies; +- have all the functionality in one file only; +- have no magic, just import the test file and execute the test functions; and +- be portable. + +To reduce the maintenance burden of developing an independent testing framework +and to leverage the benefits of existing Python testing infrastructure, SymPy +now uses pytest (and various of its plugins) to run the test suite. + +To maintain backwards compatibility with the legacy testing interface of SymPy, +which implemented functions that allowed users to run the tests on their +installed version of SymPy, the functions in this module are implemented to +match the existing API while thinly wrapping pytest. + +These two key functions are `test` and `doctest`. + +""" + +import functools +import importlib.util +import os +import pathlib +import re +from fnmatch import fnmatch +from typing import List, Optional, Tuple + +try: + import pytest +except ImportError: + + class NoPytestError(Exception): + """Raise when an internal test helper function is called with pytest.""" + + class pytest: # type: ignore + """Shadow to support pytest features when pytest can't be imported.""" + + @staticmethod + def main(*args, **kwargs): + msg = 'pytest must be installed to run tests via this function' + raise NoPytestError(msg) + +from sympy.testing.runtests import test as test_sympy + + +TESTPATHS_DEFAULT = ( + pathlib.Path('sympy'), + pathlib.Path('doc', 'src'), +) +BLACKLIST_DEFAULT = ( + 'sympy/integrals/rubi/rubi_tests/tests', +) + + +class PytestPluginManager: + """Module names for pytest plugins used by SymPy.""" + PYTEST: str = 'pytest' + RANDOMLY: str = 'pytest_randomly' + SPLIT: str = 'pytest_split' + TIMEOUT: str = 'pytest_timeout' + XDIST: str = 'xdist' + + @functools.cached_property + def has_pytest(self) -> bool: + return bool(importlib.util.find_spec(self.PYTEST)) + + @functools.cached_property + def has_randomly(self) -> bool: + return bool(importlib.util.find_spec(self.RANDOMLY)) + + @functools.cached_property + def has_split(self) -> bool: + return bool(importlib.util.find_spec(self.SPLIT)) + + @functools.cached_property + def has_timeout(self) -> bool: + return bool(importlib.util.find_spec(self.TIMEOUT)) + + @functools.cached_property + def has_xdist(self) -> bool: + return bool(importlib.util.find_spec(self.XDIST)) + + +split_pattern = re.compile(r'([1-9][0-9]*)/([1-9][0-9]*)') + + +@functools.lru_cache +def sympy_dir() -> pathlib.Path: + """Returns the root SymPy directory.""" + return pathlib.Path(__file__).parents[2] + + +def update_args_with_paths( + paths: List[str], + keywords: Optional[Tuple[str]], + args: List[str], +) -> List[str]: + """Appends valid paths and flags to the args `list` passed to `pytest.main`. + + The are three different types of "path" that a user may pass to the `paths` + positional arguments, all of which need to be handled slightly differently: + + 1. Nothing is passed + The paths to the `testpaths` defined in `pytest.ini` need to be appended + to the arguments list. + 2. Full, valid paths are passed + These paths need to be validated but can then be directly appended to + the arguments list. + 3. Partial paths are passed. + The `testpaths` defined in `pytest.ini` need to be recursed and any + matches be appended to the arguments list. + + """ + + def find_paths_matching_partial(partial_paths): + partial_path_file_patterns = [] + for partial_path in partial_paths: + if len(partial_path) >= 4: + has_test_prefix = partial_path[:4] == 'test' + has_py_suffix = partial_path[-3:] == '.py' + elif len(partial_path) >= 3: + has_test_prefix = False + has_py_suffix = partial_path[-3:] == '.py' + else: + has_test_prefix = False + has_py_suffix = False + if has_test_prefix and has_py_suffix: + partial_path_file_patterns.append(partial_path) + elif has_test_prefix: + partial_path_file_patterns.append(f'{partial_path}*.py') + elif has_py_suffix: + partial_path_file_patterns.append(f'test*{partial_path}') + else: + partial_path_file_patterns.append(f'test*{partial_path}*.py') + matches = [] + for testpath in valid_testpaths_default: + for path, dirs, files in os.walk(testpath, topdown=True): + zipped = zip(partial_paths, partial_path_file_patterns) + for (partial_path, partial_path_file) in zipped: + if fnmatch(path, f'*{partial_path}*'): + matches.append(str(pathlib.Path(path))) + dirs[:] = [] + else: + for file in files: + if fnmatch(file, partial_path_file): + matches.append(str(pathlib.Path(path, file))) + return matches + + def is_tests_file(filepath: str) -> bool: + path = pathlib.Path(filepath) + if not path.is_file(): + return False + if not path.parts[-1].startswith('test_'): + return False + if not path.suffix == '.py': + return False + return True + + def find_tests_matching_keywords(keywords, filepath): + matches = [] + source = pathlib.Path(filepath).read_text(encoding='utf-8') + for line in source.splitlines(): + if line.lstrip().startswith('def '): + for kw in keywords: + if line.lower().find(kw.lower()) != -1: + test_name = line.split(' ')[1].split('(')[0] + full_test_path = filepath + '::' + test_name + matches.append(full_test_path) + return matches + + valid_testpaths_default = [] + for testpath in TESTPATHS_DEFAULT: + absolute_testpath = pathlib.Path(sympy_dir(), testpath) + if absolute_testpath.exists(): + valid_testpaths_default.append(str(absolute_testpath)) + + candidate_paths = [] + if paths: + full_paths = [] + partial_paths = [] + for path in paths: + if pathlib.Path(path).exists(): + full_paths.append(str(pathlib.Path(sympy_dir(), path))) + else: + partial_paths.append(path) + matched_paths = find_paths_matching_partial(partial_paths) + candidate_paths.extend(full_paths) + candidate_paths.extend(matched_paths) + else: + candidate_paths.extend(valid_testpaths_default) + + if keywords is not None and keywords != (): + matches = [] + for path in candidate_paths: + if is_tests_file(path): + test_matches = find_tests_matching_keywords(keywords, path) + matches.extend(test_matches) + else: + for root, dirnames, filenames in os.walk(path): + for filename in filenames: + absolute_filepath = str(pathlib.Path(root, filename)) + if is_tests_file(absolute_filepath): + test_matches = find_tests_matching_keywords( + keywords, + absolute_filepath, + ) + matches.extend(test_matches) + args.extend(matches) + else: + args.extend(candidate_paths) + + return args + + +def make_absolute_path(partial_path: str) -> str: + """Convert a partial path to an absolute path. + + A path such a `sympy/core` might be needed. However, absolute paths should + be used in the arguments to pytest in all cases as it avoids errors that + arise from nonexistent paths. + + This function assumes that partial_paths will be passed in such that they + begin with the explicit `sympy` directory, i.e. `sympy/...`. + + """ + + def is_valid_partial_path(partial_path: str) -> bool: + """Assumption that partial paths are defined from the `sympy` root.""" + return pathlib.Path(partial_path).parts[0] == 'sympy' + + if not is_valid_partial_path(partial_path): + msg = ( + f'Partial path {dir(partial_path)} is invalid, partial paths are ' + f'expected to be defined with the `sympy` directory as the root.' + ) + raise ValueError(msg) + + absolute_path = str(pathlib.Path(sympy_dir(), partial_path)) + return absolute_path + + +def test(*paths, subprocess=True, rerun=0, **kwargs): + """Interface to run tests via pytest compatible with SymPy's test runner. + + Explanation + =========== + + Note that a `pytest.ExitCode`, which is an `enum`, is returned. This is + different to the legacy SymPy test runner which would return a `bool`. If + all tests successfully pass the `pytest.ExitCode.OK` with value `0` is + returned, whereas the legacy SymPy test runner would return `True`. In any + other scenario, a non-zero `enum` value is returned, whereas the legacy + SymPy test runner would return `False`. Users need to, therefore, be careful + if treating the pytest exit codes as booleans because + `bool(pytest.ExitCode.OK)` evaluates to `False`, the opposite of legacy + behaviour. + + Examples + ======== + + >>> import sympy # doctest: +SKIP + + Run one file: + + >>> sympy.test('sympy/core/tests/test_basic.py') # doctest: +SKIP + >>> sympy.test('_basic') # doctest: +SKIP + + Run all tests in sympy/functions/ and some particular file: + + >>> sympy.test("sympy/core/tests/test_basic.py", + ... "sympy/functions") # doctest: +SKIP + + Run all tests in sympy/core and sympy/utilities: + + >>> sympy.test("/core", "/util") # doctest: +SKIP + + Run specific test from a file: + + >>> sympy.test("sympy/core/tests/test_basic.py", + ... kw="test_equality") # doctest: +SKIP + + Run specific test from any file: + + >>> sympy.test(kw="subs") # doctest: +SKIP + + Run the tests using the legacy SymPy runner: + + >>> sympy.test(use_sympy_runner=True) # doctest: +SKIP + + Note that this option is slated for deprecation in the near future and is + only currently provided to ensure users have an alternative option while the + pytest-based runner receives real-world testing. + + Parameters + ========== + paths : first n positional arguments of strings + Paths, both partial and absolute, describing which subset(s) of the test + suite are to be run. + subprocess : bool, default is True + Legacy option, is currently ignored. + rerun : int, default is 0 + Legacy option, is ignored. + use_sympy_runner : bool or None, default is None + Temporary option to invoke the legacy SymPy test runner instead of + `pytest.main`. Will be removed in the near future. + verbose : bool, default is False + Sets the verbosity of the pytest output. Using `True` will add the + `--verbose` option to the pytest call. + tb : str, 'auto', 'long', 'short', 'line', 'native', or 'no' + Sets the traceback print mode of pytest using the `--tb` option. + kw : str + Only run tests which match the given substring expression. An expression + is a Python evaluatable expression where all names are substring-matched + against test names and their parent classes. Example: -k 'test_method or + test_other' matches all test functions and classes whose name contains + 'test_method' or 'test_other', while -k 'not test_method' matches those + that don't contain 'test_method' in their names. -k 'not test_method and + not test_other' will eliminate the matches. Additionally keywords are + matched to classes and functions containing extra names in their + 'extra_keyword_matches' set, as well as functions which have names + assigned directly to them. The matching is case-insensitive. + pdb : bool, default is False + Start the interactive Python debugger on errors or `KeyboardInterrupt`. + colors : bool, default is True + Color terminal output. + force_colors : bool, default is False + Legacy option, is ignored. + sort : bool, default is True + Run the tests in sorted order. pytest uses a sorted test order by + default. Requires pytest-randomly. + seed : int + Seed to use for random number generation. Requires pytest-randomly. + timeout : int, default is 0 + Timeout in seconds before dumping the stacks. 0 means no timeout. + Requires pytest-timeout. + fail_on_timeout : bool, default is False + Legacy option, is currently ignored. + slow : bool, default is False + Run the subset of tests marked as `slow`. + enhance_asserts : bool, default is False + Legacy option, is currently ignored. + split : string in form `/` or None, default is None + Used to split the tests up. As an example, if `split='2/3' is used then + only the middle third of tests are run. Requires pytest-split. + time_balance : bool, default is True + Legacy option, is currently ignored. + blacklist : iterable of test paths as strings, default is BLACKLIST_DEFAULT + Blacklisted test paths are ignored using the `--ignore` option. Paths + may be partial or absolute. If partial then they are matched against + all paths in the pytest tests path. + parallel : bool, default is False + Parallelize the test running using pytest-xdist. If `True` then pytest + will automatically detect the number of CPU cores available and use them + all. Requires pytest-xdist. + store_durations : bool, False + Store test durations into the file `.test_durations`. The is used by + `pytest-split` to help determine more even splits when more than one + test group is being used. Requires pytest-split. + + """ + # NOTE: to be removed alongside SymPy test runner + if kwargs.get('use_sympy_runner', False): + kwargs.pop('parallel', False) + kwargs.pop('store_durations', False) + kwargs.pop('use_sympy_runner', True) + if kwargs.get('slow') is None: + kwargs['slow'] = False + return test_sympy(*paths, subprocess=True, rerun=0, **kwargs) + + pytest_plugin_manager = PytestPluginManager() + if not pytest_plugin_manager.has_pytest: + pytest.main() + + args = [] + + if kwargs.get('verbose', False): + args.append('--verbose') + + if tb := kwargs.get('tb'): + args.extend(['--tb', tb]) + + if kwargs.get('pdb'): + args.append('--pdb') + + if not kwargs.get('colors', True): + args.extend(['--color', 'no']) + + if seed := kwargs.get('seed'): + if not pytest_plugin_manager.has_randomly: + msg = '`pytest-randomly` plugin required to control random seed.' + raise ModuleNotFoundError(msg) + args.extend(['--randomly-seed', str(seed)]) + + if kwargs.get('sort', True) and pytest_plugin_manager.has_randomly: + args.append('--randomly-dont-reorganize') + elif not kwargs.get('sort', True) and not pytest_plugin_manager.has_randomly: + msg = '`pytest-randomly` plugin required to randomize test order.' + raise ModuleNotFoundError(msg) + + if timeout := kwargs.get('timeout', None): + if not pytest_plugin_manager.has_timeout: + msg = '`pytest-timeout` plugin required to apply timeout to tests.' + raise ModuleNotFoundError(msg) + args.extend(['--timeout', str(int(timeout))]) + + # Skip slow tests by default and always skip tooslow tests + if kwargs.get('slow', False): + args.extend(['-m', 'slow and not tooslow']) + else: + args.extend(['-m', 'not slow and not tooslow']) + + if (split := kwargs.get('split')) is not None: + if not pytest_plugin_manager.has_split: + msg = '`pytest-split` plugin required to run tests as groups.' + raise ModuleNotFoundError(msg) + match = split_pattern.match(split) + if not match: + msg = ('split must be a string of the form a/b where a and b are ' + 'positive nonzero ints') + raise ValueError(msg) + group, splits = map(str, match.groups()) + args.extend(['--group', group, '--splits', splits]) + if group > splits: + msg = (f'cannot have a group number {group} with only {splits} ' + 'splits') + raise ValueError(msg) + + if blacklist := kwargs.get('blacklist', BLACKLIST_DEFAULT): + for path in blacklist: + args.extend(['--ignore', make_absolute_path(path)]) + + if kwargs.get('parallel', False): + if not pytest_plugin_manager.has_xdist: + msg = '`pytest-xdist` plugin required to run tests in parallel.' + raise ModuleNotFoundError(msg) + args.extend(['-n', 'auto']) + + if kwargs.get('store_durations', False): + if not pytest_plugin_manager.has_split: + msg = '`pytest-split` plugin required to store test durations.' + raise ModuleNotFoundError(msg) + args.append('--store-durations') + + if (keywords := kwargs.get('kw')) is not None: + keywords = tuple(str(kw) for kw in keywords) + else: + keywords = () + + args = update_args_with_paths(paths, keywords, args) + exit_code = pytest.main(args) + return exit_code + + +def doctest(): + """Interface to run doctests via pytest compatible with SymPy's test runner. + """ + raise NotImplementedError diff --git a/.venv/lib/python3.13/site-packages/sympy/testing/tmpfiles.py b/.venv/lib/python3.13/site-packages/sympy/testing/tmpfiles.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5c69cb58aa11f77679855f3df21f03a10d3b2b --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/testing/tmpfiles.py @@ -0,0 +1,46 @@ +""" +This module adds context manager for temporary files generated by the tests. +""" + +import shutil +import os + + +class TmpFileManager: + """ + A class to track record of every temporary files created by the tests. + """ + tmp_files = set('') + tmp_folders = set('') + + @classmethod + def tmp_file(cls, name=''): + cls.tmp_files.add(name) + return name + + @classmethod + def tmp_folder(cls, name=''): + cls.tmp_folders.add(name) + return name + + @classmethod + def cleanup(cls): + while cls.tmp_files: + file = cls.tmp_files.pop() + if os.path.isfile(file): + os.remove(file) + while cls.tmp_folders: + folder = cls.tmp_folders.pop() + shutil.rmtree(folder) + +def cleanup_tmp_files(test_func): + """ + A decorator to help test codes remove temporary files after the tests. + """ + def wrapper_function(): + try: + test_func() + finally: + TmpFileManager.cleanup() + + return wrapper_function diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/__init__.py b/.venv/lib/python3.13/site-packages/sympy/utilities/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8f35da4a84396618a33a12c3c6b5cf58e9d4742c --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/__init__.py @@ -0,0 +1,30 @@ +"""This module contains some general purpose utilities that are used across +SymPy. +""" +from .iterables import (flatten, group, take, subsets, + variations, numbered_symbols, cartes, capture, dict_merge, + prefixes, postfixes, sift, topological_sort, unflatten, + has_dups, has_variety, reshape, rotations) + +from .misc import filldedent + +from .lambdify import lambdify + +from .decorator import threaded, xthreaded, public, memoize_property + +from .timeutils import timed + +__all__ = [ + 'flatten', 'group', 'take', 'subsets', 'variations', 'numbered_symbols', + 'cartes', 'capture', 'dict_merge', 'prefixes', 'postfixes', 'sift', + 'topological_sort', 'unflatten', 'has_dups', 'has_variety', 'reshape', + 'rotations', + + 'filldedent', + + 'lambdify', + + 'threaded', 'xthreaded', 'public', 'memoize_property', + + 'timed', +] diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/autowrap.py b/.venv/lib/python3.13/site-packages/sympy/utilities/autowrap.py new file mode 100644 index 0000000000000000000000000000000000000000..b6c33b2f0f72f89b5680f910929618a821e924c3 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/autowrap.py @@ -0,0 +1,1178 @@ +"""Module for compiling codegen output, and wrap the binary for use in +python. + +.. note:: To use the autowrap module it must first be imported + + >>> from sympy.utilities.autowrap import autowrap + +This module provides a common interface for different external backends, such +as f2py, fwrap, Cython, SWIG(?) etc. (Currently only f2py and Cython are +implemented) The goal is to provide access to compiled binaries of acceptable +performance with a one-button user interface, e.g., + + >>> from sympy.abc import x,y + >>> expr = (x - y)**25 + >>> flat = expr.expand() + >>> binary_callable = autowrap(flat) + >>> binary_callable(2, 3) + -1.0 + +Although a SymPy user might primarily be interested in working with +mathematical expressions and not in the details of wrapping tools +needed to evaluate such expressions efficiently in numerical form, +the user cannot do so without some understanding of the +limits in the target language. For example, the expanded expression +contains large coefficients which result in loss of precision when +computing the expression: + + >>> binary_callable(3, 2) + 0.0 + >>> binary_callable(4, 5), binary_callable(5, 4) + (-22925376.0, 25165824.0) + +Wrapping the unexpanded expression gives the expected behavior: + + >>> e = autowrap(expr) + >>> e(4, 5), e(5, 4) + (-1.0, 1.0) + +The callable returned from autowrap() is a binary Python function, not a +SymPy object. If it is desired to use the compiled function in symbolic +expressions, it is better to use binary_function() which returns a SymPy +Function object. The binary callable is attached as the _imp_ attribute and +invoked when a numerical evaluation is requested with evalf(), or with +lambdify(). + + >>> from sympy.utilities.autowrap import binary_function + >>> f = binary_function('f', expr) + >>> 2*f(x, y) + y + y + 2*f(x, y) + >>> (2*f(x, y) + y).evalf(2, subs={x: 1, y:2}) + 0.e-110 + +When is this useful? + + 1) For computations on large arrays, Python iterations may be too slow, + and depending on the mathematical expression, it may be difficult to + exploit the advanced index operations provided by NumPy. + + 2) For *really* long expressions that will be called repeatedly, the + compiled binary should be significantly faster than SymPy's .evalf() + + 3) If you are generating code with the codegen utility in order to use + it in another project, the automatic Python wrappers let you test the + binaries immediately from within SymPy. + + 4) To create customized ufuncs for use with numpy arrays. + See *ufuncify*. + +When is this module NOT the best approach? + + 1) If you are really concerned about speed or memory optimizations, + you will probably get better results by working directly with the + wrapper tools and the low level code. However, the files generated + by this utility may provide a useful starting point and reference + code. Temporary files will be left intact if you supply the keyword + tempdir="path/to/files/". + + 2) If the array computation can be handled easily by numpy, and you + do not need the binaries for another project. + +""" + +import sys +import os +import shutil +import tempfile +from pathlib import Path +from subprocess import STDOUT, CalledProcessError, check_output +from string import Template +from warnings import warn + +from sympy.core.cache import cacheit +from sympy.core.function import Lambda +from sympy.core.relational import Eq +from sympy.core.symbol import Dummy, Symbol +from sympy.tensor.indexed import Idx, IndexedBase +from sympy.utilities.codegen import (make_routine, get_code_generator, + OutputArgument, InOutArgument, + InputArgument, CodeGenArgumentListError, + Result, ResultBase, C99CodeGen) +from sympy.utilities.iterables import iterable +from sympy.utilities.lambdify import implemented_function +from sympy.utilities.decorator import doctest_depends_on + +_doctest_depends_on = {'exe': ('f2py', 'gfortran', 'gcc'), + 'modules': ('numpy',)} + + +class CodeWrapError(Exception): + pass + + +class CodeWrapper: + """Base Class for code wrappers""" + _filename = "wrapped_code" + _module_basename = "wrapper_module" + _module_counter = 0 + + @property + def filename(self): + return "%s_%s" % (self._filename, CodeWrapper._module_counter) + + @property + def module_name(self): + return "%s_%s" % (self._module_basename, CodeWrapper._module_counter) + + def __init__(self, generator, filepath=None, flags=[], verbose=False): + """ + generator -- the code generator to use + """ + self.generator = generator + self.filepath = filepath + self.flags = flags + self.quiet = not verbose + + @property + def include_header(self): + return bool(self.filepath) + + @property + def include_empty(self): + return bool(self.filepath) + + def _generate_code(self, main_routine, routines): + routines.append(main_routine) + self.generator.write( + routines, self.filename, True, self.include_header, + self.include_empty) + + def wrap_code(self, routine, helpers=None): + helpers = helpers or [] + if self.filepath: + workdir = os.path.abspath(self.filepath) + else: + workdir = tempfile.mkdtemp("_sympy_compile") + if not os.access(workdir, os.F_OK): + os.mkdir(workdir) + oldwork = os.getcwd() + os.chdir(workdir) + try: + sys.path.append(workdir) + self._generate_code(routine, helpers) + self._prepare_files(routine) + self._process_files(routine) + mod = __import__(self.module_name) + finally: + sys.path.remove(workdir) + CodeWrapper._module_counter += 1 + os.chdir(oldwork) + if not self.filepath: + try: + shutil.rmtree(workdir) + except OSError: + # Could be some issues on Windows + pass + + return self._get_wrapped_function(mod, routine.name) + + def _process_files(self, routine): + command = self.command + command.extend(self.flags) + try: + retoutput = check_output(command, stderr=STDOUT) + except CalledProcessError as e: + raise CodeWrapError( + "Error while executing command: %s. Command output is:\n%s" % ( + " ".join(command), e.output.decode('utf-8'))) + if not self.quiet: + print(retoutput) + + +class DummyWrapper(CodeWrapper): + """Class used for testing independent of backends """ + + template = """# dummy module for testing of SymPy +def %(name)s(): + return "%(expr)s" +%(name)s.args = "%(args)s" +%(name)s.returns = "%(retvals)s" +""" + + def _prepare_files(self, routine): + return + + def _generate_code(self, routine, helpers): + with open('%s.py' % self.module_name, 'w') as f: + printed = ", ".join( + [str(res.expr) for res in routine.result_variables]) + # convert OutputArguments to return value like f2py + args = filter(lambda x: not isinstance( + x, OutputArgument), routine.arguments) + retvals = [] + for val in routine.result_variables: + if isinstance(val, Result): + retvals.append('nameless') + else: + retvals.append(val.result_var) + + print(DummyWrapper.template % { + 'name': routine.name, + 'expr': printed, + 'args': ", ".join([str(a.name) for a in args]), + 'retvals': ", ".join([str(val) for val in retvals]) + }, end="", file=f) + + def _process_files(self, routine): + return + + @classmethod + def _get_wrapped_function(cls, mod, name): + return getattr(mod, name) + + +class CythonCodeWrapper(CodeWrapper): + """Wrapper that uses Cython""" + + setup_template = """\ +from setuptools import setup +from setuptools import Extension +from Cython.Build import cythonize +cy_opts = {cythonize_options} +{np_import} +ext_mods = [Extension( + {ext_args}, + include_dirs={include_dirs}, + library_dirs={library_dirs}, + libraries={libraries}, + extra_compile_args={extra_compile_args}, + extra_link_args={extra_link_args} +)] +setup(ext_modules=cythonize(ext_mods, **cy_opts)) +""" + + _cythonize_options = {'compiler_directives':{'language_level' : "3"}} + + pyx_imports = ( + "import numpy as np\n" + "cimport numpy as np\n\n") + + pyx_header = ( + "cdef extern from '{header_file}.h':\n" + " {prototype}\n\n") + + pyx_func = ( + "def {name}_c({arg_string}):\n" + "\n" + "{declarations}" + "{body}") + + std_compile_flag = '-std=c99' + + def __init__(self, *args, **kwargs): + """Instantiates a Cython code wrapper. + + The following optional parameters get passed to ``setuptools.Extension`` + for building the Python extension module. Read its documentation to + learn more. + + Parameters + ========== + include_dirs : [list of strings] + A list of directories to search for C/C++ header files (in Unix + form for portability). + library_dirs : [list of strings] + A list of directories to search for C/C++ libraries at link time. + libraries : [list of strings] + A list of library names (not filenames or paths) to link against. + extra_compile_args : [list of strings] + Any extra platform- and compiler-specific information to use when + compiling the source files in 'sources'. For platforms and + compilers where "command line" makes sense, this is typically a + list of command-line arguments, but for other platforms it could be + anything. Note that the attribute ``std_compile_flag`` will be + appended to this list. + extra_link_args : [list of strings] + Any extra platform- and compiler-specific information to use when + linking object files together to create the extension (or to create + a new static Python interpreter). Similar interpretation as for + 'extra_compile_args'. + cythonize_options : [dictionary] + Keyword arguments passed on to cythonize. + + """ + + self._include_dirs = kwargs.pop('include_dirs', []) + self._library_dirs = kwargs.pop('library_dirs', []) + self._libraries = kwargs.pop('libraries', []) + self._extra_compile_args = kwargs.pop('extra_compile_args', []) + self._extra_compile_args.append(self.std_compile_flag) + self._extra_link_args = kwargs.pop('extra_link_args', []) + self._cythonize_options = kwargs.pop('cythonize_options', self._cythonize_options) + + self._need_numpy = False + + super().__init__(*args, **kwargs) + + @property + def command(self): + command = [sys.executable, "setup.py", "build_ext", "--inplace"] + return command + + def _prepare_files(self, routine, build_dir=os.curdir): + # NOTE : build_dir is used for testing purposes. + pyxfilename = self.module_name + '.pyx' + codefilename = "%s.%s" % (self.filename, self.generator.code_extension) + + # pyx + with open(os.path.join(build_dir, pyxfilename), 'w') as f: + self.dump_pyx([routine], f, self.filename) + + # setup.py + ext_args = [repr(self.module_name), repr([pyxfilename, codefilename])] + if self._need_numpy: + np_import = 'import numpy as np\n' + self._include_dirs.append('np.get_include()') + else: + np_import = '' + + includes = str(self._include_dirs).replace("'np.get_include()'", + 'np.get_include()') + code = self.setup_template.format( + ext_args=", ".join(ext_args), + np_import=np_import, + include_dirs=includes, + library_dirs=self._library_dirs, + libraries=self._libraries, + extra_compile_args=self._extra_compile_args, + extra_link_args=self._extra_link_args, + cythonize_options=self._cythonize_options) + Path(os.path.join(build_dir, 'setup.py')).write_text(code) + + @classmethod + def _get_wrapped_function(cls, mod, name): + return getattr(mod, name + '_c') + + def dump_pyx(self, routines, f, prefix): + """Write a Cython file with Python wrappers + + This file contains all the definitions of the routines in c code and + refers to the header file. + + Arguments + --------- + routines + List of Routine instances + f + File-like object to write the file to + prefix + The filename prefix, used to refer to the proper header file. + Only the basename of the prefix is used. + """ + headers = [] + functions = [] + for routine in routines: + prototype = self.generator.get_prototype(routine) + + # C Function Header Import + headers.append(self.pyx_header.format(header_file=prefix, + prototype=prototype)) + + # Partition the C function arguments into categories + py_rets, py_args, py_loc, py_inf = self._partition_args(routine.arguments) + + # Function prototype + name = routine.name + arg_string = ", ".join(self._prototype_arg(arg) for arg in py_args) + + # Local Declarations + local_decs = [] + for arg, val in py_inf.items(): + proto = self._prototype_arg(arg) + mat, ind = [self._string_var(v) for v in val] + local_decs.append(" cdef {} = {}.shape[{}]".format(proto, mat, ind)) + local_decs.extend([" cdef {}".format(self._declare_arg(a)) for a in py_loc]) + declarations = "\n".join(local_decs) + if declarations: + declarations = declarations + "\n" + + # Function Body + args_c = ", ".join([self._call_arg(a) for a in routine.arguments]) + rets = ", ".join([self._string_var(r.name) for r in py_rets]) + if routine.results: + body = ' return %s(%s)' % (routine.name, args_c) + if rets: + body = body + ', ' + rets + else: + body = ' %s(%s)\n' % (routine.name, args_c) + body = body + ' return ' + rets + + functions.append(self.pyx_func.format(name=name, arg_string=arg_string, + declarations=declarations, body=body)) + + # Write text to file + if self._need_numpy: + # Only import numpy if required + f.write(self.pyx_imports) + f.write('\n'.join(headers)) + f.write('\n'.join(functions)) + + def _partition_args(self, args): + """Group function arguments into categories.""" + py_args = [] + py_returns = [] + py_locals = [] + py_inferred = {} + for arg in args: + if isinstance(arg, OutputArgument): + py_returns.append(arg) + py_locals.append(arg) + elif isinstance(arg, InOutArgument): + py_returns.append(arg) + py_args.append(arg) + else: + py_args.append(arg) + # Find arguments that are array dimensions. These can be inferred + # locally in the Cython code. + if isinstance(arg, (InputArgument, InOutArgument)) and arg.dimensions: + dims = [d[1] + 1 for d in arg.dimensions] + sym_dims = [(i, d) for (i, d) in enumerate(dims) if + isinstance(d, Symbol)] + for (i, d) in sym_dims: + py_inferred[d] = (arg.name, i) + for arg in args: + if arg.name in py_inferred: + py_inferred[arg] = py_inferred.pop(arg.name) + # Filter inferred arguments from py_args + py_args = [a for a in py_args if a not in py_inferred] + return py_returns, py_args, py_locals, py_inferred + + def _prototype_arg(self, arg): + mat_dec = "np.ndarray[{mtype}, ndim={ndim}] {name}" + np_types = {'double': 'np.double_t', + 'int': 'np.int_t'} + t = arg.get_datatype('c') + if arg.dimensions: + self._need_numpy = True + ndim = len(arg.dimensions) + mtype = np_types[t] + return mat_dec.format(mtype=mtype, ndim=ndim, name=self._string_var(arg.name)) + else: + return "%s %s" % (t, self._string_var(arg.name)) + + def _declare_arg(self, arg): + proto = self._prototype_arg(arg) + if arg.dimensions: + shape = '(' + ','.join(self._string_var(i[1] + 1) for i in arg.dimensions) + ')' + return proto + " = np.empty({shape})".format(shape=shape) + else: + return proto + " = 0" + + def _call_arg(self, arg): + if arg.dimensions: + t = arg.get_datatype('c') + return "<{}*> {}.data".format(t, self._string_var(arg.name)) + elif isinstance(arg, ResultBase): + return "&{}".format(self._string_var(arg.name)) + else: + return self._string_var(arg.name) + + def _string_var(self, var): + printer = self.generator.printer.doprint + return printer(var) + + +class F2PyCodeWrapper(CodeWrapper): + """Wrapper that uses f2py""" + + def __init__(self, *args, **kwargs): + + ext_keys = ['include_dirs', 'library_dirs', 'libraries', + 'extra_compile_args', 'extra_link_args'] + msg = ('The compilation option kwarg {} is not supported with the f2py ' + 'backend.') + + for k in ext_keys: + if k in kwargs.keys(): + warn(msg.format(k)) + kwargs.pop(k, None) + + super().__init__(*args, **kwargs) + + @property + def command(self): + filename = self.filename + '.' + self.generator.code_extension + args = ['-c', '-m', self.module_name, filename] + command = [sys.executable, "-c", "import numpy.f2py as f2py2e;f2py2e.main()"]+args + return command + + def _prepare_files(self, routine): + pass + + @classmethod + def _get_wrapped_function(cls, mod, name): + return getattr(mod, name) + + +# Here we define a lookup of backends -> tuples of languages. For now, each +# tuple is of length 1, but if a backend supports more than one language, +# the most preferable language is listed first. +_lang_lookup = {'CYTHON': ('C99', 'C89', 'C'), + 'F2PY': ('F95',), + 'NUMPY': ('C99', 'C89', 'C'), + 'DUMMY': ('F95',)} # Dummy here just for testing + + +def _infer_language(backend): + """For a given backend, return the top choice of language""" + langs = _lang_lookup.get(backend.upper(), False) + if not langs: + raise ValueError("Unrecognized backend: " + backend) + return langs[0] + + +def _validate_backend_language(backend, language): + """Throws error if backend and language are incompatible""" + langs = _lang_lookup.get(backend.upper(), False) + if not langs: + raise ValueError("Unrecognized backend: " + backend) + if language.upper() not in langs: + raise ValueError(("Backend {} and language {} are " + "incompatible").format(backend, language)) + + +@cacheit +@doctest_depends_on(exe=('f2py', 'gfortran'), modules=('numpy',)) +def autowrap(expr, language=None, backend='f2py', tempdir=None, args=None, + flags=None, verbose=False, helpers=None, code_gen=None, **kwargs): + """Generates Python callable binaries based on the math expression. + + Parameters + ========== + + expr + The SymPy expression that should be wrapped as a binary routine. + language : string, optional + If supplied, (options: 'C' or 'F95'), specifies the language of the + generated code. If ``None`` [default], the language is inferred based + upon the specified backend. + backend : string, optional + Backend used to wrap the generated code. Either 'f2py' [default], + or 'cython'. + tempdir : string, optional + Path to directory for temporary files. If this argument is supplied, + the generated code and the wrapper input files are left intact in the + specified path. + args : iterable, optional + An ordered iterable of symbols. Specifies the argument sequence for the + function. + flags : iterable, optional + Additional option flags that will be passed to the backend. + verbose : bool, optional + If True, autowrap will not mute the command line backends. This can be + helpful for debugging. + helpers : 3-tuple or iterable of 3-tuples, optional + Used to define auxiliary functions needed for the main expression. + Each tuple should be of the form (name, expr, args) where: + + - name : str, the function name + - expr : sympy expression, the function + - args : iterable, the function arguments (can be any iterable of symbols) + + code_gen : CodeGen instance + An instance of a CodeGen subclass. Overrides ``language``. + include_dirs : [string] + A list of directories to search for C/C++ header files (in Unix form + for portability). + library_dirs : [string] + A list of directories to search for C/C++ libraries at link time. + libraries : [string] + A list of library names (not filenames or paths) to link against. + extra_compile_args : [string] + Any extra platform- and compiler-specific information to use when + compiling the source files in 'sources'. For platforms and compilers + where "command line" makes sense, this is typically a list of + command-line arguments, but for other platforms it could be anything. + extra_link_args : [string] + Any extra platform- and compiler-specific information to use when + linking object files together to create the extension (or to create a + new static Python interpreter). Similar interpretation as for + 'extra_compile_args'. + + Examples + ======== + + Basic usage: + + >>> from sympy.abc import x, y, z + >>> from sympy.utilities.autowrap import autowrap + >>> expr = ((x - y + z)**(13)).expand() + >>> binary_func = autowrap(expr) + >>> binary_func(1, 4, 2) + -1.0 + + Using helper functions: + + >>> from sympy.abc import x, t + >>> from sympy import Function + >>> helper_func = Function('helper_func') # Define symbolic function + >>> expr = 3*x + helper_func(t) # Main expression using helper function + >>> # Define helper_func(x) = 4*x using f2py backend + >>> binary_func = autowrap(expr, args=[x, t], + ... helpers=('helper_func', 4*x, [x])) + >>> binary_func(2, 5) # 3*2 + helper_func(5) = 6 + 20 + 26.0 + >>> # Same example using cython backend + >>> binary_func = autowrap(expr, args=[x, t], backend='cython', + ... helpers=[('helper_func', 4*x, [x])]) + >>> binary_func(2, 5) # 3*2 + helper_func(5) = 6 + 20 + 26.0 + + Type handling example: + + >>> import numpy as np + >>> expr = x + y + >>> f_cython = autowrap(expr, backend='cython') + >>> f_cython(1, 2) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + TypeError: Argument '_x' has incorrect type (expected numpy.ndarray, got int) + >>> f_cython(np.array([1.0]), np.array([2.0])) + array([ 3.]) + + """ + if language: + if not isinstance(language, type): + _validate_backend_language(backend, language) + else: + language = _infer_language(backend) + + # two cases 1) helpers is an iterable of 3-tuples and 2) helpers is a + # 3-tuple + if iterable(helpers) and len(helpers) != 0 and iterable(helpers[0]): + helpers = helpers if helpers else () + else: + helpers = [helpers] if helpers else () + args = list(args) if iterable(args, exclude=set) else args + + if code_gen is None: + code_gen = get_code_generator(language, "autowrap") + + CodeWrapperClass = { + 'F2PY': F2PyCodeWrapper, + 'CYTHON': CythonCodeWrapper, + 'DUMMY': DummyWrapper + }[backend.upper()] + code_wrapper = CodeWrapperClass(code_gen, tempdir, flags if flags else (), + verbose, **kwargs) + + helps = [] + for name_h, expr_h, args_h in helpers: + helps.append(code_gen.routine(name_h, expr_h, args_h)) + + for name_h, expr_h, args_h in helpers: + if expr.has(expr_h): + name_h = binary_function(name_h, expr_h, backend='dummy') + expr = expr.subs(expr_h, name_h(*args_h)) + try: + routine = code_gen.routine('autofunc', expr, args) + except CodeGenArgumentListError as e: + # if all missing arguments are for pure output, we simply attach them + # at the end and try again, because the wrappers will silently convert + # them to return values anyway. + new_args = [] + for missing in e.missing_args: + if not isinstance(missing, OutputArgument): + raise + new_args.append(missing.name) + routine = code_gen.routine('autofunc', expr, args + new_args) + + return code_wrapper.wrap_code(routine, helpers=helps) + + +@doctest_depends_on(exe=('f2py', 'gfortran'), modules=('numpy',)) +def binary_function(symfunc, expr, **kwargs): + """Returns a SymPy function with expr as binary implementation + + This is a convenience function that automates the steps needed to + autowrap the SymPy expression and attaching it to a Function object + with implemented_function(). + + Parameters + ========== + + symfunc : SymPy Function + The function to bind the callable to. + expr : SymPy Expression + The expression used to generate the function. + kwargs : dict + Any kwargs accepted by autowrap. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.utilities.autowrap import binary_function + >>> expr = ((x - y)**(25)).expand() + >>> f = binary_function('f', expr) + >>> type(f) + + >>> 2*f(x, y) + 2*f(x, y) + >>> f(x, y).evalf(2, subs={x: 1, y: 2}) + -1.0 + + """ + binary = autowrap(expr, **kwargs) + return implemented_function(symfunc, binary) + +################################################################# +# UFUNCIFY # +################################################################# + +_ufunc_top = Template("""\ +#include "Python.h" +#include "math.h" +#include "numpy/ndarraytypes.h" +#include "numpy/ufuncobject.h" +#include "numpy/halffloat.h" +#include ${include_file} + +static PyMethodDef ${module}Methods[] = { + {NULL, NULL, 0, NULL} +};""") + +_ufunc_outcalls = Template("*((double *)out${outnum}) = ${funcname}(${call_args});") + +_ufunc_body = Template("""\ +#ifdef NPY_1_19_API_VERSION +static void ${funcname}_ufunc(char **args, const npy_intp *dimensions, const npy_intp* steps, void* data) +#else +static void ${funcname}_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data) +#endif +{ + npy_intp i; + npy_intp n = dimensions[0]; + ${declare_args} + ${declare_steps} + for (i = 0; i < n; i++) { + ${outcalls} + ${step_increments} + } +} +PyUFuncGenericFunction ${funcname}_funcs[1] = {&${funcname}_ufunc}; +static char ${funcname}_types[${n_types}] = ${types} +static void *${funcname}_data[1] = {NULL};""") + +_ufunc_bottom = Template("""\ +#if PY_VERSION_HEX >= 0x03000000 +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + "${module}", + NULL, + -1, + ${module}Methods, + NULL, + NULL, + NULL, + NULL +}; + +PyMODINIT_FUNC PyInit_${module}(void) +{ + PyObject *m, *d; + ${function_creation} + m = PyModule_Create(&moduledef); + if (!m) { + return NULL; + } + import_array(); + import_umath(); + d = PyModule_GetDict(m); + ${ufunc_init} + return m; +} +#else +PyMODINIT_FUNC init${module}(void) +{ + PyObject *m, *d; + ${function_creation} + m = Py_InitModule("${module}", ${module}Methods); + if (m == NULL) { + return; + } + import_array(); + import_umath(); + d = PyModule_GetDict(m); + ${ufunc_init} +} +#endif\ +""") + +_ufunc_init_form = Template("""\ +ufunc${ind} = PyUFunc_FromFuncAndData(${funcname}_funcs, ${funcname}_data, ${funcname}_types, 1, ${n_in}, ${n_out}, + PyUFunc_None, "${module}", ${docstring}, 0); + PyDict_SetItemString(d, "${funcname}", ufunc${ind}); + Py_DECREF(ufunc${ind});""") + +_ufunc_setup = Template("""\ +from setuptools.extension import Extension +from setuptools import setup + +from numpy import get_include + +if __name__ == "__main__": + setup(ext_modules=[ + Extension('${module}', + sources=['${module}.c', '${filename}.c'], + include_dirs=[get_include()])]) +""") + + +class UfuncifyCodeWrapper(CodeWrapper): + """Wrapper for Ufuncify""" + + def __init__(self, *args, **kwargs): + + ext_keys = ['include_dirs', 'library_dirs', 'libraries', + 'extra_compile_args', 'extra_link_args'] + msg = ('The compilation option kwarg {} is not supported with the numpy' + ' backend.') + + for k in ext_keys: + if k in kwargs.keys(): + warn(msg.format(k)) + kwargs.pop(k, None) + + super().__init__(*args, **kwargs) + + @property + def command(self): + command = [sys.executable, "setup.py", "build_ext", "--inplace"] + return command + + def wrap_code(self, routines, helpers=None): + # This routine overrides CodeWrapper because we can't assume funcname == routines[0].name + # Therefore we have to break the CodeWrapper private API. + # There isn't an obvious way to extend multi-expr support to + # the other autowrap backends, so we limit this change to ufuncify. + helpers = helpers if helpers is not None else [] + # We just need a consistent name + funcname = 'wrapped_' + str(id(routines) + id(helpers)) + + workdir = self.filepath or tempfile.mkdtemp("_sympy_compile") + if not os.access(workdir, os.F_OK): + os.mkdir(workdir) + oldwork = os.getcwd() + os.chdir(workdir) + try: + sys.path.append(workdir) + self._generate_code(routines, helpers) + self._prepare_files(routines, funcname) + self._process_files(routines) + mod = __import__(self.module_name) + finally: + sys.path.remove(workdir) + CodeWrapper._module_counter += 1 + os.chdir(oldwork) + if not self.filepath: + try: + shutil.rmtree(workdir) + except OSError: + # Could be some issues on Windows + pass + + return self._get_wrapped_function(mod, funcname) + + def _generate_code(self, main_routines, helper_routines): + all_routines = main_routines + helper_routines + self.generator.write( + all_routines, self.filename, True, self.include_header, + self.include_empty) + + def _prepare_files(self, routines, funcname): + + # C + codefilename = self.module_name + '.c' + with open(codefilename, 'w') as f: + self.dump_c(routines, f, self.filename, funcname=funcname) + + # setup.py + with open('setup.py', 'w') as f: + self.dump_setup(f) + + @classmethod + def _get_wrapped_function(cls, mod, name): + return getattr(mod, name) + + def dump_setup(self, f): + setup = _ufunc_setup.substitute(module=self.module_name, + filename=self.filename) + f.write(setup) + + def dump_c(self, routines, f, prefix, funcname=None): + """Write a C file with Python wrappers + + This file contains all the definitions of the routines in c code. + + Arguments + --------- + routines + List of Routine instances + f + File-like object to write the file to + prefix + The filename prefix, used to name the imported module. + funcname + Name of the main function to be returned. + """ + if funcname is None: + if len(routines) == 1: + funcname = routines[0].name + else: + msg = 'funcname must be specified for multiple output routines' + raise ValueError(msg) + functions = [] + function_creation = [] + ufunc_init = [] + module = self.module_name + include_file = "\"{}.h\"".format(prefix) + top = _ufunc_top.substitute(include_file=include_file, module=module) + + name = funcname + + # Partition the C function arguments into categories + # Here we assume all routines accept the same arguments + r_index = 0 + py_in, _ = self._partition_args(routines[0].arguments) + n_in = len(py_in) + n_out = len(routines) + + # Declare Args + form = "char *{0}{1} = args[{2}];" + arg_decs = [form.format('in', i, i) for i in range(n_in)] + arg_decs.extend([form.format('out', i, i+n_in) for i in range(n_out)]) + declare_args = '\n '.join(arg_decs) + + # Declare Steps + form = "npy_intp {0}{1}_step = steps[{2}];" + step_decs = [form.format('in', i, i) for i in range(n_in)] + step_decs.extend([form.format('out', i, i+n_in) for i in range(n_out)]) + declare_steps = '\n '.join(step_decs) + + # Call Args + form = "*(double *)in{0}" + call_args = ', '.join([form.format(a) for a in range(n_in)]) + + # Step Increments + form = "{0}{1} += {0}{1}_step;" + step_incs = [form.format('in', i) for i in range(n_in)] + step_incs.extend([form.format('out', i, i) for i in range(n_out)]) + step_increments = '\n '.join(step_incs) + + # Types + n_types = n_in + n_out + types = "{" + ', '.join(["NPY_DOUBLE"]*n_types) + "};" + + # Docstring + docstring = '"Created in SymPy with Ufuncify"' + + # Function Creation + function_creation.append("PyObject *ufunc{};".format(r_index)) + + # Ufunc initialization + init_form = _ufunc_init_form.substitute(module=module, + funcname=name, + docstring=docstring, + n_in=n_in, n_out=n_out, + ind=r_index) + ufunc_init.append(init_form) + + outcalls = [_ufunc_outcalls.substitute( + outnum=i, call_args=call_args, funcname=routines[i].name) for i in + range(n_out)] + + body = _ufunc_body.substitute(module=module, funcname=name, + declare_args=declare_args, + declare_steps=declare_steps, + call_args=call_args, + step_increments=step_increments, + n_types=n_types, types=types, + outcalls='\n '.join(outcalls)) + functions.append(body) + + body = '\n\n'.join(functions) + ufunc_init = '\n '.join(ufunc_init) + function_creation = '\n '.join(function_creation) + bottom = _ufunc_bottom.substitute(module=module, + ufunc_init=ufunc_init, + function_creation=function_creation) + text = [top, body, bottom] + f.write('\n\n'.join(text)) + + def _partition_args(self, args): + """Group function arguments into categories.""" + py_in = [] + py_out = [] + for arg in args: + if isinstance(arg, OutputArgument): + py_out.append(arg) + elif isinstance(arg, InOutArgument): + raise ValueError("Ufuncify doesn't support InOutArguments") + else: + py_in.append(arg) + return py_in, py_out + + +@cacheit +@doctest_depends_on(exe=('f2py', 'gfortran', 'gcc'), modules=('numpy',)) +def ufuncify(args, expr, language=None, backend='numpy', tempdir=None, + flags=None, verbose=False, helpers=None, **kwargs): + """Generates a binary function that supports broadcasting on numpy arrays. + + Parameters + ========== + + args : iterable + Either a Symbol or an iterable of symbols. Specifies the argument + sequence for the function. + expr + A SymPy expression that defines the element wise operation. + language : string, optional + If supplied, (options: 'C' or 'F95'), specifies the language of the + generated code. If ``None`` [default], the language is inferred based + upon the specified backend. + backend : string, optional + Backend used to wrap the generated code. Either 'numpy' [default], + 'cython', or 'f2py'. + tempdir : string, optional + Path to directory for temporary files. If this argument is supplied, + the generated code and the wrapper input files are left intact in + the specified path. + flags : iterable, optional + Additional option flags that will be passed to the backend. + verbose : bool, optional + If True, autowrap will not mute the command line backends. This can + be helpful for debugging. + helpers : 3-tuple or iterable of 3-tuples, optional + Used to define auxiliary functions needed for the main expression. + Each tuple should be of the form (name, expr, args) where: + + - name : str, the function name + - expr : sympy expression, the function + - args : iterable, the function arguments (can be any iterable of symbols) + + kwargs : dict + These kwargs will be passed to autowrap if the `f2py` or `cython` + backend is used and ignored if the `numpy` backend is used. + + Notes + ===== + + The default backend ('numpy') will create actual instances of + ``numpy.ufunc``. These support ndimensional broadcasting, and implicit type + conversion. Use of the other backends will result in a "ufunc-like" + function, which requires equal length 1-dimensional arrays for all + arguments, and will not perform any type conversions. + + References + ========== + + .. [1] https://numpy.org/doc/stable/reference/ufuncs.html + + Examples + ======== + + Basic usage: + + >>> from sympy.utilities.autowrap import ufuncify + >>> from sympy.abc import x, y + >>> import numpy as np + >>> f = ufuncify((x, y), y + x**2) + >>> type(f) + + >>> f([1, 2, 3], 2) + array([ 3., 6., 11.]) + >>> f(np.arange(5), 3) + array([ 3., 4., 7., 12., 19.]) + + Using helper functions: + + >>> from sympy import Function + >>> helper_func = Function('helper_func') # Define symbolic function + >>> expr = x**2 + y*helper_func(x) # Main expression using helper function + >>> # Define helper_func(x) = x**3 + >>> f = ufuncify((x, y), expr, helpers=[('helper_func', x**3, [x])]) + >>> f([1, 2], [3, 4]) + array([ 4., 36.]) + + Type handling with different backends: + + For the 'f2py' and 'cython' backends, inputs are required to be equal length + 1-dimensional arrays. The 'f2py' backend will perform type conversion, but + the Cython backend will error if the inputs are not of the expected type. + + >>> f_fortran = ufuncify((x, y), y + x**2, backend='f2py') + >>> f_fortran(1, 2) + array([ 3.]) + >>> f_fortran(np.array([1, 2, 3]), np.array([1.0, 2.0, 3.0])) + array([ 2., 6., 12.]) + >>> f_cython = ufuncify((x, y), y + x**2, backend='Cython') + >>> f_cython(1, 2) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + TypeError: Argument '_x' has incorrect type (expected numpy.ndarray, got int) + >>> f_cython(np.array([1.0]), np.array([2.0])) + array([ 3.]) + + """ + + if isinstance(args, Symbol): + args = (args,) + else: + args = tuple(args) + + if language: + _validate_backend_language(backend, language) + else: + language = _infer_language(backend) + + helpers = helpers if helpers else () + flags = flags if flags else () + + if backend.upper() == 'NUMPY': + # maxargs is set by numpy compile-time constant NPY_MAXARGS + # If a future version of numpy modifies or removes this restriction + # this variable should be changed or removed + maxargs = 32 + helps = [] + for name, expr, args in helpers: + helps.append(make_routine(name, expr, args)) + code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify"), tempdir, + flags, verbose) + if not isinstance(expr, (list, tuple)): + expr = [expr] + if len(expr) == 0: + raise ValueError('Expression iterable has zero length') + if len(expr) + len(args) > maxargs: + msg = ('Cannot create ufunc with more than {0} total arguments: ' + 'got {1} in, {2} out') + raise ValueError(msg.format(maxargs, len(args), len(expr))) + routines = [make_routine('autofunc{}'.format(idx), exprx, args) for + idx, exprx in enumerate(expr)] + return code_wrapper.wrap_code(routines, helpers=helps) + else: + # Dummies are used for all added expressions to prevent name clashes + # within the original expression. + y = IndexedBase(Dummy('y')) + m = Dummy('m', integer=True) + i = Idx(Dummy('i', integer=True), m) + f_dummy = Dummy('f') + f = implemented_function('%s_%d' % (f_dummy.name, f_dummy.dummy_index), Lambda(args, expr)) + # For each of the args create an indexed version. + indexed_args = [IndexedBase(Dummy(str(a))) for a in args] + # Order the arguments (out, args, dim) + args = [y] + indexed_args + [m] + args_with_indices = [a[i] for a in indexed_args] + return autowrap(Eq(y[i], f(*args_with_indices)), language, backend, + tempdir, args, flags, verbose, helpers, **kwargs) diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/codegen.py b/.venv/lib/python3.13/site-packages/sympy/utilities/codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..9ac8772fc000d707ae33c67eaa44b4c281157ab0 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/codegen.py @@ -0,0 +1,2237 @@ +""" +module for generating C, C++, Fortran77, Fortran90, Julia, Rust +and Octave/Matlab routines that evaluate SymPy expressions. +This module is work in progress. +Only the milestones with a '+' character in the list below have been completed. + +--- How is sympy.utilities.codegen different from sympy.printing.ccode? --- + +We considered the idea to extend the printing routines for SymPy functions in +such a way that it prints complete compilable code, but this leads to a few +unsurmountable issues that can only be tackled with dedicated code generator: + +- For C, one needs both a code and a header file, while the printing routines + generate just one string. This code generator can be extended to support + .pyf files for f2py. + +- SymPy functions are not concerned with programming-technical issues, such + as input, output and input-output arguments. Other examples are contiguous + or non-contiguous arrays, including headers of other libraries such as gsl + or others. + +- It is highly interesting to evaluate several SymPy functions in one C + routine, eventually sharing common intermediate results with the help + of the cse routine. This is more than just printing. + +- From the programming perspective, expressions with constants should be + evaluated in the code generator as much as possible. This is different + for printing. + +--- Basic assumptions --- + +* A generic Routine data structure describes the routine that must be + translated into C/Fortran/... code. This data structure covers all + features present in one or more of the supported languages. + +* Descendants from the CodeGen class transform multiple Routine instances + into compilable code. Each derived class translates into a specific + language. + +* In many cases, one wants a simple workflow. The friendly functions in the + last part are a simple api on top of the Routine/CodeGen stuff. They are + easier to use, but are less powerful. + +--- Milestones --- + ++ First working version with scalar input arguments, generating C code, + tests ++ Friendly functions that are easier to use than the rigorous + Routine/CodeGen workflow. ++ Integer and Real numbers as input and output ++ Output arguments ++ InputOutput arguments ++ Sort input/output arguments properly ++ Contiguous array arguments (numpy matrices) ++ Also generate .pyf code for f2py (in autowrap module) ++ Isolate constants and evaluate them beforehand in double precision ++ Fortran 90 ++ Octave/Matlab + +- Common Subexpression Elimination +- User defined comments in the generated code +- Optional extra include lines for libraries/objects that can eval special + functions +- Test other C compilers and libraries: gcc, tcc, libtcc, gcc+gsl, ... +- Contiguous array arguments (SymPy matrices) +- Non-contiguous array arguments (SymPy matrices) +- ccode must raise an error when it encounters something that cannot be + translated into c. ccode(integrate(sin(x)/x, x)) does not make sense. +- Complex numbers as input and output +- A default complex datatype +- Include extra information in the header: date, user, hostname, sha1 + hash, ... +- Fortran 77 +- C++ +- Python +- Julia +- Rust +- ... + +""" + +import os +import textwrap +from io import StringIO + +from sympy import __version__ as sympy_version +from sympy.core import Symbol, S, Tuple, Equality, Function, Basic +from sympy.printing.c import c_code_printers +from sympy.printing.codeprinter import AssignmentError +from sympy.printing.fortran import FCodePrinter +from sympy.printing.julia import JuliaCodePrinter +from sympy.printing.octave import OctaveCodePrinter +from sympy.printing.rust import RustCodePrinter +from sympy.tensor import Idx, Indexed, IndexedBase +from sympy.matrices import (MatrixSymbol, ImmutableMatrix, MatrixBase, + MatrixExpr, MatrixSlice) +from sympy.utilities.iterables import is_sequence + + +__all__ = [ + # description of routines + "Routine", "DataType", "default_datatypes", "get_default_datatype", + "Argument", "InputArgument", "OutputArgument", "Result", + # routines -> code + "CodeGen", "CCodeGen", "FCodeGen", "JuliaCodeGen", "OctaveCodeGen", + "RustCodeGen", + # friendly functions + "codegen", "make_routine", +] + + +# +# Description of routines +# + + +class Routine: + """Generic description of evaluation routine for set of expressions. + + A CodeGen class can translate instances of this class into code in a + particular language. The routine specification covers all the features + present in these languages. The CodeGen part must raise an exception + when certain features are not present in the target language. For + example, multiple return values are possible in Python, but not in C or + Fortran. Another example: Fortran and Python support complex numbers, + while C does not. + + """ + + def __init__(self, name, arguments, results, local_vars, global_vars): + """Initialize a Routine instance. + + Parameters + ========== + + name : string + Name of the routine. + + arguments : list of Arguments + These are things that appear in arguments of a routine, often + appearing on the right-hand side of a function call. These are + commonly InputArguments but in some languages, they can also be + OutputArguments or InOutArguments (e.g., pass-by-reference in C + code). + + results : list of Results + These are the return values of the routine, often appearing on + the left-hand side of a function call. The difference between + Results and OutputArguments and when you should use each is + language-specific. + + local_vars : list of Results + These are variables that will be defined at the beginning of the + function. + + global_vars : list of Symbols + Variables which will not be passed into the function. + + """ + + # extract all input symbols and all symbols appearing in an expression + input_symbols = set() + symbols = set() + for arg in arguments: + if isinstance(arg, OutputArgument): + symbols.update(arg.expr.free_symbols - arg.expr.atoms(Indexed)) + elif isinstance(arg, InputArgument): + input_symbols.add(arg.name) + elif isinstance(arg, InOutArgument): + input_symbols.add(arg.name) + symbols.update(arg.expr.free_symbols - arg.expr.atoms(Indexed)) + else: + raise ValueError("Unknown Routine argument: %s" % arg) + + for r in results: + if not isinstance(r, Result): + raise ValueError("Unknown Routine result: %s" % r) + symbols.update(r.expr.free_symbols - r.expr.atoms(Indexed)) + + local_symbols = set() + for r in local_vars: + if isinstance(r, Result): + symbols.update(r.expr.free_symbols - r.expr.atoms(Indexed)) + local_symbols.add(r.name) + else: + local_symbols.add(r) + + symbols = {s.label if isinstance(s, Idx) else s for s in symbols} + + # Check that all symbols in the expressions are covered by + # InputArguments/InOutArguments---subset because user could + # specify additional (unused) InputArguments or local_vars. + notcovered = symbols.difference( + input_symbols.union(local_symbols).union(global_vars)) + if notcovered != set(): + raise ValueError("Symbols needed for output are not in input " + + ", ".join([str(x) for x in notcovered])) + + self.name = name + self.arguments = arguments + self.results = results + self.local_vars = local_vars + self.global_vars = global_vars + + def __str__(self): + return self.__class__.__name__ + "({name!r}, {arguments}, {results}, {local_vars}, {global_vars})".format(**self.__dict__) + + __repr__ = __str__ + + @property + def variables(self): + """Returns a set of all variables possibly used in the routine. + + For routines with unnamed return values, the dummies that may or + may not be used will be included in the set. + + """ + v = set(self.local_vars) + v.update(arg.name for arg in self.arguments) + v.update(res.result_var for res in self.results) + return v + + @property + def result_variables(self): + """Returns a list of OutputArgument, InOutArgument and Result. + + If return values are present, they are at the end of the list. + """ + args = [arg for arg in self.arguments if isinstance( + arg, (OutputArgument, InOutArgument))] + args.extend(self.results) + return args + + +class DataType: + """Holds strings for a certain datatype in different languages.""" + def __init__(self, cname, fname, pyname, jlname, octname, rsname): + self.cname = cname + self.fname = fname + self.pyname = pyname + self.jlname = jlname + self.octname = octname + self.rsname = rsname + + +default_datatypes = { + "int": DataType("int", "INTEGER*4", "int", "", "", "i32"), + "float": DataType("double", "REAL*8", "float", "", "", "f64"), + "complex": DataType("double", "COMPLEX*16", "complex", "", "", "float") #FIXME: + # complex is only supported in fortran, python, julia, and octave. + # So to not break c or rust code generation, we stick with double or + # float, respectively (but actually should raise an exception for + # explicitly complex variables (x.is_complex==True)) +} + + +COMPLEX_ALLOWED = False +def get_default_datatype(expr, complex_allowed=None): + """Derives an appropriate datatype based on the expression.""" + if complex_allowed is None: + complex_allowed = COMPLEX_ALLOWED + if complex_allowed: + final_dtype = "complex" + else: + final_dtype = "float" + if expr.is_integer: + return default_datatypes["int"] + elif expr.is_real: + return default_datatypes["float"] + elif isinstance(expr, MatrixBase): + #check all entries + dt = "int" + for element in expr: + if dt == "int" and not element.is_integer: + dt = "float" + if dt == "float" and not element.is_real: + return default_datatypes[final_dtype] + return default_datatypes[dt] + else: + return default_datatypes[final_dtype] + + +class Variable: + """Represents a typed variable.""" + + def __init__(self, name, datatype=None, dimensions=None, precision=None): + """Return a new variable. + + Parameters + ========== + + name : Symbol or MatrixSymbol + + datatype : optional + When not given, the data type will be guessed based on the + assumptions on the symbol argument. + + dimensions : sequence containing tuples, optional + If present, the argument is interpreted as an array, where this + sequence of tuples specifies (lower, upper) bounds for each + index of the array. + + precision : int, optional + Controls the precision of floating point constants. + + """ + if not isinstance(name, (Symbol, MatrixSymbol)): + raise TypeError("The first argument must be a SymPy symbol.") + if datatype is None: + datatype = get_default_datatype(name) + elif not isinstance(datatype, DataType): + raise TypeError("The (optional) `datatype' argument must be an " + "instance of the DataType class.") + if dimensions and not isinstance(dimensions, (tuple, list)): + raise TypeError( + "The dimensions argument must be a sequence of tuples") + + self._name = name + self._datatype = { + 'C': datatype.cname, + 'FORTRAN': datatype.fname, + 'JULIA': datatype.jlname, + 'OCTAVE': datatype.octname, + 'PYTHON': datatype.pyname, + 'RUST': datatype.rsname, + } + self.dimensions = dimensions + self.precision = precision + + def __str__(self): + return "%s(%r)" % (self.__class__.__name__, self.name) + + __repr__ = __str__ + + @property + def name(self): + return self._name + + def get_datatype(self, language): + """Returns the datatype string for the requested language. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.utilities.codegen import Variable + >>> x = Variable(Symbol('x')) + >>> x.get_datatype('c') + 'double' + >>> x.get_datatype('fortran') + 'REAL*8' + + """ + try: + return self._datatype[language.upper()] + except KeyError: + raise CodeGenError("Has datatypes for languages: %s" % + ", ".join(self._datatype)) + + +class Argument(Variable): + """An abstract Argument data structure: a name and a data type. + + This structure is refined in the descendants below. + + """ + pass + + +class InputArgument(Argument): + pass + + +class ResultBase: + """Base class for all "outgoing" information from a routine. + + Objects of this class stores a SymPy expression, and a SymPy object + representing a result variable that will be used in the generated code + only if necessary. + + """ + def __init__(self, expr, result_var): + self.expr = expr + self.result_var = result_var + + def __str__(self): + return "%s(%r, %r)" % (self.__class__.__name__, self.expr, + self.result_var) + + __repr__ = __str__ + + +class OutputArgument(Argument, ResultBase): + """OutputArgument are always initialized in the routine.""" + + def __init__(self, name, result_var, expr, datatype=None, dimensions=None, precision=None): + """Return a new variable. + + Parameters + ========== + + name : Symbol, MatrixSymbol + The name of this variable. When used for code generation, this + might appear, for example, in the prototype of function in the + argument list. + + result_var : Symbol, Indexed + Something that can be used to assign a value to this variable. + Typically the same as `name` but for Indexed this should be e.g., + "y[i]" whereas `name` should be the Symbol "y". + + expr : object + The expression that should be output, typically a SymPy + expression. + + datatype : optional + When not given, the data type will be guessed based on the + assumptions on the symbol argument. + + dimensions : sequence containing tuples, optional + If present, the argument is interpreted as an array, where this + sequence of tuples specifies (lower, upper) bounds for each + index of the array. + + precision : int, optional + Controls the precision of floating point constants. + + """ + + Argument.__init__(self, name, datatype, dimensions, precision) + ResultBase.__init__(self, expr, result_var) + + def __str__(self): + return "%s(%r, %r, %r)" % (self.__class__.__name__, self.name, self.result_var, self.expr) + + __repr__ = __str__ + + +class InOutArgument(Argument, ResultBase): + """InOutArgument are never initialized in the routine.""" + + def __init__(self, name, result_var, expr, datatype=None, dimensions=None, precision=None): + if not datatype: + datatype = get_default_datatype(expr) + Argument.__init__(self, name, datatype, dimensions, precision) + ResultBase.__init__(self, expr, result_var) + __init__.__doc__ = OutputArgument.__init__.__doc__ + + + def __str__(self): + return "%s(%r, %r, %r)" % (self.__class__.__name__, self.name, self.expr, + self.result_var) + + __repr__ = __str__ + + +class Result(Variable, ResultBase): + """An expression for a return value. + + The name result is used to avoid conflicts with the reserved word + "return" in the Python language. It is also shorter than ReturnValue. + + These may or may not need a name in the destination (e.g., "return(x*y)" + might return a value without ever naming it). + + """ + + def __init__(self, expr, name=None, result_var=None, datatype=None, + dimensions=None, precision=None): + """Initialize a return value. + + Parameters + ========== + + expr : SymPy expression + + name : Symbol, MatrixSymbol, optional + The name of this return variable. When used for code generation, + this might appear, for example, in the prototype of function in a + list of return values. A dummy name is generated if omitted. + + result_var : Symbol, Indexed, optional + Something that can be used to assign a value to this variable. + Typically the same as `name` but for Indexed this should be e.g., + "y[i]" whereas `name` should be the Symbol "y". Defaults to + `name` if omitted. + + datatype : optional + When not given, the data type will be guessed based on the + assumptions on the expr argument. + + dimensions : sequence containing tuples, optional + If present, this variable is interpreted as an array, + where this sequence of tuples specifies (lower, upper) + bounds for each index of the array. + + precision : int, optional + Controls the precision of floating point constants. + + """ + # Basic because it is the base class for all types of expressions + if not isinstance(expr, (Basic, MatrixBase)): + raise TypeError("The first argument must be a SymPy expression.") + + if name is None: + name = 'result_%d' % abs(hash(expr)) + + if datatype is None: + #try to infer data type from the expression + datatype = get_default_datatype(expr) + + if isinstance(name, str): + if isinstance(expr, (MatrixBase, MatrixExpr)): + name = MatrixSymbol(name, *expr.shape) + else: + name = Symbol(name) + + if result_var is None: + result_var = name + + Variable.__init__(self, name, datatype=datatype, + dimensions=dimensions, precision=precision) + ResultBase.__init__(self, expr, result_var) + + def __str__(self): + return "%s(%r, %r, %r)" % (self.__class__.__name__, self.expr, self.name, + self.result_var) + + __repr__ = __str__ + + +# +# Transformation of routine objects into code +# + +class CodeGen: + """Abstract class for the code generators.""" + + printer = None # will be set to an instance of a CodePrinter subclass + + def _indent_code(self, codelines): + return self.printer.indent_code(codelines) + + def _printer_method_with_settings(self, method, settings=None, *args, **kwargs): + settings = settings or {} + ori = {k: self.printer._settings[k] for k in settings} + for k, v in settings.items(): + self.printer._settings[k] = v + result = getattr(self.printer, method)(*args, **kwargs) + for k, v in ori.items(): + self.printer._settings[k] = v + return result + + def _get_symbol(self, s): + """Returns the symbol as fcode prints it.""" + if self.printer._settings['human']: + expr_str = self.printer.doprint(s) + else: + constants, not_supported, expr_str = self.printer.doprint(s) + if constants or not_supported: + raise ValueError("Failed to print %s" % str(s)) + return expr_str.strip() + + def __init__(self, project="project", cse=False): + """Initialize a code generator. + + Derived classes will offer more options that affect the generated + code. + + """ + self.project = project + self.cse = cse + + def routine(self, name, expr, argument_sequence=None, global_vars=None): + """Creates an Routine object that is appropriate for this language. + + This implementation is appropriate for at least C/Fortran. Subclasses + can override this if necessary. + + Here, we assume at most one return value (the l-value) which must be + scalar. Additional outputs are OutputArguments (e.g., pointers on + right-hand-side or pass-by-reference). Matrices are always returned + via OutputArguments. If ``argument_sequence`` is None, arguments will + be ordered alphabetically, but with all InputArguments first, and then + OutputArgument and InOutArguments. + + """ + + if self.cse: + from sympy.simplify.cse_main import cse + + if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)): + if not expr: + raise ValueError("No expression given") + for e in expr: + if not e.is_Equality: + raise CodeGenError("Lists of expressions must all be Equalities. {} is not.".format(e)) + + # create a list of right hand sides and simplify them + rhs = [e.rhs for e in expr] + common, simplified = cse(rhs) + + # pack the simplified expressions back up with their left hand sides + expr = [Equality(e.lhs, rhs) for e, rhs in zip(expr, simplified)] + else: + if isinstance(expr, Equality): + common, simplified = cse(expr.rhs) #, ignore=in_out_args) + expr = Equality(expr.lhs, simplified[0]) + else: + common, simplified = cse(expr) + expr = simplified + + local_vars = [Result(b,a) for a,b in common] + local_symbols = {a for a,_ in common} + local_expressions = Tuple(*[b for _,b in common]) + else: + local_expressions = Tuple() + + if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)): + if not expr: + raise ValueError("No expression given") + expressions = Tuple(*expr) + else: + expressions = Tuple(expr) + + if self.cse: + if {i.label for i in expressions.atoms(Idx)} != set(): + raise CodeGenError("CSE and Indexed expressions do not play well together yet") + else: + # local variables for indexed expressions + local_vars = {i.label for i in expressions.atoms(Idx)} + local_symbols = local_vars + + # global variables + global_vars = set() if global_vars is None else set(global_vars) + + # symbols that should be arguments + symbols = (expressions.free_symbols | local_expressions.free_symbols) - local_symbols - global_vars + new_symbols = set() + new_symbols.update(symbols) + + for symbol in symbols: + if isinstance(symbol, Idx): + new_symbols.remove(symbol) + new_symbols.update(symbol.args[1].free_symbols) + if isinstance(symbol, Indexed): + new_symbols.remove(symbol) + symbols = new_symbols + + # Decide whether to use output argument or return value + return_val = [] + output_args = [] + for expr in expressions: + if isinstance(expr, Equality): + out_arg = expr.lhs + expr = expr.rhs + if isinstance(out_arg, Indexed): + dims = tuple([ (S.Zero, dim - 1) for dim in out_arg.shape]) + symbol = out_arg.base.label + elif isinstance(out_arg, Symbol): + dims = [] + symbol = out_arg + elif isinstance(out_arg, MatrixSymbol): + dims = tuple([ (S.Zero, dim - 1) for dim in out_arg.shape]) + symbol = out_arg + else: + raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol " + "can define output arguments.") + + if expr.has(symbol): + output_args.append( + InOutArgument(symbol, out_arg, expr, dimensions=dims)) + else: + output_args.append( + OutputArgument(symbol, out_arg, expr, dimensions=dims)) + + # remove duplicate arguments when they are not local variables + if symbol not in local_vars: + # avoid duplicate arguments + symbols.remove(symbol) + elif isinstance(expr, (ImmutableMatrix, MatrixSlice)): + # Create a "dummy" MatrixSymbol to use as the Output arg + out_arg = MatrixSymbol('out_%s' % abs(hash(expr)), *expr.shape) + dims = tuple([(S.Zero, dim - 1) for dim in out_arg.shape]) + output_args.append( + OutputArgument(out_arg, out_arg, expr, dimensions=dims)) + else: + return_val.append(Result(expr)) + + arg_list = [] + + # setup input argument list + + # helper to get dimensions for data for array-like args + def dimensions(s): + return [(S.Zero, dim - 1) for dim in s.shape] + + array_symbols = {} + for array in expressions.atoms(Indexed) | local_expressions.atoms(Indexed): + array_symbols[array.base.label] = array + for array in expressions.atoms(MatrixSymbol) | local_expressions.atoms(MatrixSymbol): + array_symbols[array] = array + + for symbol in sorted(symbols, key=str): + if symbol in array_symbols: + array = array_symbols[symbol] + metadata = {'dimensions': dimensions(array)} + else: + metadata = {} + + arg_list.append(InputArgument(symbol, **metadata)) + + output_args.sort(key=lambda x: str(x.name)) + arg_list.extend(output_args) + + if argument_sequence is not None: + # if the user has supplied IndexedBase instances, we'll accept that + new_sequence = [] + for arg in argument_sequence: + if isinstance(arg, IndexedBase): + new_sequence.append(arg.label) + else: + new_sequence.append(arg) + argument_sequence = new_sequence + + missing = [x for x in arg_list if x.name not in argument_sequence] + if missing: + msg = "Argument list didn't specify: {0} " + msg = msg.format(", ".join([str(m.name) for m in missing])) + raise CodeGenArgumentListError(msg, missing) + + # create redundant arguments to produce the requested sequence + name_arg_dict = {x.name: x for x in arg_list} + new_args = [] + for symbol in argument_sequence: + try: + new_args.append(name_arg_dict[symbol]) + except KeyError: + if isinstance(symbol, (IndexedBase, MatrixSymbol)): + metadata = {'dimensions': dimensions(symbol)} + else: + metadata = {} + new_args.append(InputArgument(symbol, **metadata)) + arg_list = new_args + + return Routine(name, arg_list, return_val, local_vars, global_vars) + + def write(self, routines, prefix, to_files=False, header=True, empty=True): + """Writes all the source code files for the given routines. + + The generated source is returned as a list of (filename, contents) + tuples, or is written to files (see below). Each filename consists + of the given prefix, appended with an appropriate extension. + + Parameters + ========== + + routines : list + A list of Routine instances to be written + + prefix : string + The prefix for the output files + + to_files : bool, optional + When True, the output is written to files. Otherwise, a list + of (filename, contents) tuples is returned. [default: False] + + header : bool, optional + When True, a header comment is included on top of each source + file. [default: True] + + empty : bool, optional + When True, empty lines are included to structure the source + files. [default: True] + + """ + if to_files: + for dump_fn in self.dump_fns: + filename = "%s.%s" % (prefix, dump_fn.extension) + with open(filename, "w") as f: + dump_fn(self, routines, f, prefix, header, empty) + else: + result = [] + for dump_fn in self.dump_fns: + filename = "%s.%s" % (prefix, dump_fn.extension) + contents = StringIO() + dump_fn(self, routines, contents, prefix, header, empty) + result.append((filename, contents.getvalue())) + return result + + def dump_code(self, routines, f, prefix, header=True, empty=True): + """Write the code by calling language specific methods. + + The generated file contains all the definitions of the routines in + low-level code and refers to the header file if appropriate. + + Parameters + ========== + + routines : list + A list of Routine instances. + + f : file-like + Where to write the file. + + prefix : string + The filename prefix, used to refer to the proper header file. + Only the basename of the prefix is used. + + header : bool, optional + When True, a header comment is included on top of each source + file. [default : True] + + empty : bool, optional + When True, empty lines are included to structure the source + files. [default : True] + + """ + + code_lines = self._preprocessor_statements(prefix) + + for routine in routines: + if empty: + code_lines.append("\n") + code_lines.extend(self._get_routine_opening(routine)) + code_lines.extend(self._declare_arguments(routine)) + code_lines.extend(self._declare_globals(routine)) + code_lines.extend(self._declare_locals(routine)) + if empty: + code_lines.append("\n") + code_lines.extend(self._call_printer(routine)) + if empty: + code_lines.append("\n") + code_lines.extend(self._get_routine_ending(routine)) + + code_lines = self._indent_code(''.join(code_lines)) + + if header: + code_lines = ''.join(self._get_header() + [code_lines]) + + if code_lines: + f.write(code_lines) + + +class CodeGenError(Exception): + pass + + +class CodeGenArgumentListError(Exception): + @property + def missing_args(self): + return self.args[1] + + +header_comment = """Code generated with SymPy %(version)s + +See http://www.sympy.org/ for more information. + +This file is part of '%(project)s' +""" + + +class CCodeGen(CodeGen): + """Generator for C code. + + The .write() method inherited from CodeGen will output a code file and + an interface file, .c and .h respectively. + + """ + + code_extension = "c" + interface_extension = "h" + standard = 'c99' + + def __init__(self, project="project", printer=None, + preprocessor_statements=None, cse=False): + super().__init__(project=project, cse=cse) + self.printer = printer or c_code_printers[self.standard.lower()]() + + self.preprocessor_statements = preprocessor_statements + if preprocessor_statements is None: + self.preprocessor_statements = ['#include '] + + def _get_header(self): + """Writes a common header for the generated files.""" + code_lines = [] + code_lines.append("/" + "*"*78 + '\n') + tmp = header_comment % {"version": sympy_version, + "project": self.project} + for line in tmp.splitlines(): + code_lines.append(" *%s*\n" % line.center(76)) + code_lines.append(" " + "*"*78 + "/\n") + return code_lines + + def get_prototype(self, routine): + """Returns a string for the function prototype of the routine. + + If the routine has multiple result objects, an CodeGenError is + raised. + + See: https://en.wikipedia.org/wiki/Function_prototype + + """ + if len(routine.results) > 1: + raise CodeGenError("C only supports a single or no return value.") + elif len(routine.results) == 1: + ctype = routine.results[0].get_datatype('C') + else: + ctype = "void" + + type_args = [] + for arg in routine.arguments: + name = self.printer.doprint(arg.name) + if arg.dimensions or isinstance(arg, ResultBase): + type_args.append((arg.get_datatype('C'), "*%s" % name)) + else: + type_args.append((arg.get_datatype('C'), name)) + arguments = ", ".join([ "%s %s" % t for t in type_args]) + return "%s %s(%s)" % (ctype, routine.name, arguments) + + def _preprocessor_statements(self, prefix): + code_lines = [] + code_lines.append('#include "{}.h"'.format(os.path.basename(prefix))) + code_lines.extend(self.preprocessor_statements) + code_lines = ['{}\n'.format(l) for l in code_lines] + return code_lines + + def _get_routine_opening(self, routine): + prototype = self.get_prototype(routine) + return ["%s {\n" % prototype] + + def _declare_arguments(self, routine): + # arguments are declared in prototype + return [] + + def _declare_globals(self, routine): + # global variables are not explicitly declared within C functions + return [] + + def _declare_locals(self, routine): + + # Compose a list of symbols to be dereferenced in the function + # body. These are the arguments that were passed by a reference + # pointer, excluding arrays. + dereference = [] + for arg in routine.arguments: + if isinstance(arg, ResultBase) and not arg.dimensions: + dereference.append(arg.name) + + code_lines = [] + for result in routine.local_vars: + + # local variables that are simple symbols such as those used as indices into + # for loops are defined declared elsewhere. + if not isinstance(result, Result): + continue + + if result.name != result.result_var: + raise CodeGen("Result variable and name should match: {}".format(result)) + assign_to = result.name + t = result.get_datatype('c') + if isinstance(result.expr, (MatrixBase, MatrixExpr)): + dims = result.expr.shape + code_lines.append("{} {}[{}];\n".format(t, str(assign_to), dims[0]*dims[1])) + prefix = "" + else: + prefix = "const {} ".format(t) + + constants, not_c, c_expr = self._printer_method_with_settings( + 'doprint', {"human": False, "dereference": dereference, "strict": False}, + result.expr, assign_to=assign_to) + + for name, value in sorted(constants, key=str): + code_lines.append("double const %s = %s;\n" % (name, value)) + + code_lines.append("{}{}\n".format(prefix, c_expr)) + + return code_lines + + def _call_printer(self, routine): + code_lines = [] + + # Compose a list of symbols to be dereferenced in the function + # body. These are the arguments that were passed by a reference + # pointer, excluding arrays. + dereference = [] + for arg in routine.arguments: + if isinstance(arg, ResultBase) and not arg.dimensions: + dereference.append(arg.name) + + return_val = None + for result in routine.result_variables: + if isinstance(result, Result): + assign_to = routine.name + "_result" + t = result.get_datatype('c') + code_lines.append("{} {};\n".format(t, str(assign_to))) + return_val = assign_to + else: + assign_to = result.result_var + + try: + constants, not_c, c_expr = self._printer_method_with_settings( + 'doprint', {"human": False, "dereference": dereference, "strict": False}, + result.expr, assign_to=assign_to) + except AssignmentError: + assign_to = result.result_var + code_lines.append( + "%s %s;\n" % (result.get_datatype('c'), str(assign_to))) + constants, not_c, c_expr = self._printer_method_with_settings( + 'doprint', {"human": False, "dereference": dereference, "strict": False}, + result.expr, assign_to=assign_to) + + for name, value in sorted(constants, key=str): + code_lines.append("double const %s = %s;\n" % (name, value)) + code_lines.append("%s\n" % c_expr) + + if return_val: + code_lines.append(" return %s;\n" % return_val) + return code_lines + + def _get_routine_ending(self, routine): + return ["}\n"] + + def dump_c(self, routines, f, prefix, header=True, empty=True): + self.dump_code(routines, f, prefix, header, empty) + dump_c.extension = code_extension # type: ignore + dump_c.__doc__ = CodeGen.dump_code.__doc__ + + def dump_h(self, routines, f, prefix, header=True, empty=True): + """Writes the C header file. + + This file contains all the function declarations. + + Parameters + ========== + + routines : list + A list of Routine instances. + + f : file-like + Where to write the file. + + prefix : string + The filename prefix, used to construct the include guards. + Only the basename of the prefix is used. + + header : bool, optional + When True, a header comment is included on top of each source + file. [default : True] + + empty : bool, optional + When True, empty lines are included to structure the source + files. [default : True] + + """ + if header: + print(''.join(self._get_header()), file=f) + guard_name = "%s__%s__H" % (self.project.replace( + " ", "_").upper(), prefix.replace("/", "_").upper()) + # include guards + if empty: + print(file=f) + print("#ifndef %s" % guard_name, file=f) + print("#define %s" % guard_name, file=f) + if empty: + print(file=f) + # declaration of the function prototypes + for routine in routines: + prototype = self.get_prototype(routine) + print("%s;" % prototype, file=f) + # end if include guards + if empty: + print(file=f) + print("#endif", file=f) + if empty: + print(file=f) + dump_h.extension = interface_extension # type: ignore + + # This list of dump functions is used by CodeGen.write to know which dump + # functions it has to call. + dump_fns = [dump_c, dump_h] + +class C89CodeGen(CCodeGen): + standard = 'C89' + +class C99CodeGen(CCodeGen): + standard = 'C99' + +class FCodeGen(CodeGen): + """Generator for Fortran 95 code + + The .write() method inherited from CodeGen will output a code file and + an interface file, .f90 and .h respectively. + + """ + + code_extension = "f90" + interface_extension = "h" + + def __init__(self, project='project', printer=None): + super().__init__(project) + self.printer = printer or FCodePrinter() + + def _get_header(self): + """Writes a common header for the generated files.""" + code_lines = [] + code_lines.append("!" + "*"*78 + '\n') + tmp = header_comment % {"version": sympy_version, + "project": self.project} + for line in tmp.splitlines(): + code_lines.append("!*%s*\n" % line.center(76)) + code_lines.append("!" + "*"*78 + '\n') + return code_lines + + def _preprocessor_statements(self, prefix): + return [] + + def _get_routine_opening(self, routine): + """Returns the opening statements of the fortran routine.""" + code_list = [] + if len(routine.results) > 1: + raise CodeGenError( + "Fortran only supports a single or no return value.") + elif len(routine.results) == 1: + result = routine.results[0] + code_list.append(result.get_datatype('fortran')) + code_list.append("function") + else: + code_list.append("subroutine") + + args = ", ".join("%s" % self._get_symbol(arg.name) + for arg in routine.arguments) + + call_sig = "{}({})\n".format(routine.name, args) + # Fortran 95 requires all lines be less than 132 characters, so wrap + # this line before appending. + call_sig = ' &\n'.join(textwrap.wrap(call_sig, + width=60, + break_long_words=False)) + '\n' + code_list.append(call_sig) + code_list = [' '.join(code_list)] + code_list.append('implicit none\n') + return code_list + + def _declare_arguments(self, routine): + # argument type declarations + code_list = [] + array_list = [] + scalar_list = [] + for arg in routine.arguments: + + if isinstance(arg, InputArgument): + typeinfo = "%s, intent(in)" % arg.get_datatype('fortran') + elif isinstance(arg, InOutArgument): + typeinfo = "%s, intent(inout)" % arg.get_datatype('fortran') + elif isinstance(arg, OutputArgument): + typeinfo = "%s, intent(out)" % arg.get_datatype('fortran') + else: + raise CodeGenError("Unknown Argument type: %s" % type(arg)) + + fprint = self._get_symbol + + if arg.dimensions: + # fortran arrays start at 1 + dimstr = ", ".join(["%s:%s" % ( + fprint(dim[0] + 1), fprint(dim[1] + 1)) + for dim in arg.dimensions]) + typeinfo += ", dimension(%s)" % dimstr + array_list.append("%s :: %s\n" % (typeinfo, fprint(arg.name))) + else: + scalar_list.append("%s :: %s\n" % (typeinfo, fprint(arg.name))) + + # scalars first, because they can be used in array declarations + code_list.extend(scalar_list) + code_list.extend(array_list) + + return code_list + + def _declare_globals(self, routine): + # Global variables not explicitly declared within Fortran 90 functions. + # Note: a future F77 mode may need to generate "common" blocks. + return [] + + def _declare_locals(self, routine): + code_list = [] + for var in sorted(routine.local_vars, key=str): + typeinfo = get_default_datatype(var) + code_list.append("%s :: %s\n" % ( + typeinfo.fname, self._get_symbol(var))) + return code_list + + def _get_routine_ending(self, routine): + """Returns the closing statements of the fortran routine.""" + if len(routine.results) == 1: + return ["end function\n"] + else: + return ["end subroutine\n"] + + def get_interface(self, routine): + """Returns a string for the function interface. + + The routine should have a single result object, which can be None. + If the routine has multiple result objects, a CodeGenError is + raised. + + See: https://en.wikipedia.org/wiki/Function_prototype + + """ + prototype = [ "interface\n" ] + prototype.extend(self._get_routine_opening(routine)) + prototype.extend(self._declare_arguments(routine)) + prototype.extend(self._get_routine_ending(routine)) + prototype.append("end interface\n") + + return "".join(prototype) + + def _call_printer(self, routine): + declarations = [] + code_lines = [] + for result in routine.result_variables: + if isinstance(result, Result): + assign_to = routine.name + elif isinstance(result, (OutputArgument, InOutArgument)): + assign_to = result.result_var + + constants, not_fortran, f_expr = self._printer_method_with_settings( + 'doprint', {"human": False, "source_format": 'free', "standard": 95, "strict": False}, + result.expr, assign_to=assign_to) + + for obj, v in sorted(constants, key=str): + t = get_default_datatype(obj) + declarations.append( + "%s, parameter :: %s = %s\n" % (t.fname, obj, v)) + for obj in sorted(not_fortran, key=str): + t = get_default_datatype(obj) + if isinstance(obj, Function): + name = obj.func + else: + name = obj + declarations.append("%s :: %s\n" % (t.fname, name)) + + code_lines.append("%s\n" % f_expr) + return declarations + code_lines + + def _indent_code(self, codelines): + return self._printer_method_with_settings( + 'indent_code', {"human": False, "source_format": 'free', "strict": False}, codelines) + + def dump_f95(self, routines, f, prefix, header=True, empty=True): + # check that symbols are unique with ignorecase + for r in routines: + lowercase = {str(x).lower() for x in r.variables} + orig_case = {str(x) for x in r.variables} + if len(lowercase) < len(orig_case): + raise CodeGenError("Fortran ignores case. Got symbols: %s" % + (", ".join([str(var) for var in r.variables]))) + self.dump_code(routines, f, prefix, header, empty) + dump_f95.extension = code_extension # type: ignore + dump_f95.__doc__ = CodeGen.dump_code.__doc__ + + def dump_h(self, routines, f, prefix, header=True, empty=True): + """Writes the interface to a header file. + + This file contains all the function declarations. + + Parameters + ========== + + routines : list + A list of Routine instances. + + f : file-like + Where to write the file. + + prefix : string + The filename prefix. + + header : bool, optional + When True, a header comment is included on top of each source + file. [default : True] + + empty : bool, optional + When True, empty lines are included to structure the source + files. [default : True] + + """ + if header: + print(''.join(self._get_header()), file=f) + if empty: + print(file=f) + # declaration of the function prototypes + for routine in routines: + prototype = self.get_interface(routine) + f.write(prototype) + if empty: + print(file=f) + dump_h.extension = interface_extension # type: ignore + + # This list of dump functions is used by CodeGen.write to know which dump + # functions it has to call. + dump_fns = [dump_f95, dump_h] + + +class JuliaCodeGen(CodeGen): + """Generator for Julia code. + + The .write() method inherited from CodeGen will output a code file + .jl. + + """ + + code_extension = "jl" + + def __init__(self, project='project', printer=None): + super().__init__(project) + self.printer = printer or JuliaCodePrinter() + + def routine(self, name, expr, argument_sequence, global_vars): + """Specialized Routine creation for Julia.""" + + if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)): + if not expr: + raise ValueError("No expression given") + expressions = Tuple(*expr) + else: + expressions = Tuple(expr) + + # local variables + local_vars = {i.label for i in expressions.atoms(Idx)} + + # global variables + global_vars = set() if global_vars is None else set(global_vars) + + # symbols that should be arguments + old_symbols = expressions.free_symbols - local_vars - global_vars + symbols = set() + for s in old_symbols: + if isinstance(s, Idx): + symbols.update(s.args[1].free_symbols) + elif not isinstance(s, Indexed): + symbols.add(s) + + # Julia supports multiple return values + return_vals = [] + output_args = [] + for (i, expr) in enumerate(expressions): + if isinstance(expr, Equality): + out_arg = expr.lhs + expr = expr.rhs + symbol = out_arg + if isinstance(out_arg, Indexed): + dims = tuple([ (S.One, dim) for dim in out_arg.shape]) + symbol = out_arg.base.label + output_args.append(InOutArgument(symbol, out_arg, expr, dimensions=dims)) + if not isinstance(out_arg, (Indexed, Symbol, MatrixSymbol)): + raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol " + "can define output arguments.") + + return_vals.append(Result(expr, name=symbol, result_var=out_arg)) + if not expr.has(symbol): + # this is a pure output: remove from the symbols list, so + # it doesn't become an input. + symbols.remove(symbol) + + else: + # we have no name for this output + return_vals.append(Result(expr, name='out%d' % (i+1))) + + # setup input argument list + output_args.sort(key=lambda x: str(x.name)) + arg_list = list(output_args) + array_symbols = {} + for array in expressions.atoms(Indexed): + array_symbols[array.base.label] = array + for array in expressions.atoms(MatrixSymbol): + array_symbols[array] = array + + for symbol in sorted(symbols, key=str): + arg_list.append(InputArgument(symbol)) + + if argument_sequence is not None: + # if the user has supplied IndexedBase instances, we'll accept that + new_sequence = [] + for arg in argument_sequence: + if isinstance(arg, IndexedBase): + new_sequence.append(arg.label) + else: + new_sequence.append(arg) + argument_sequence = new_sequence + + missing = [x for x in arg_list if x.name not in argument_sequence] + if missing: + msg = "Argument list didn't specify: {0} " + msg = msg.format(", ".join([str(m.name) for m in missing])) + raise CodeGenArgumentListError(msg, missing) + + # create redundant arguments to produce the requested sequence + name_arg_dict = {x.name: x for x in arg_list} + new_args = [] + for symbol in argument_sequence: + try: + new_args.append(name_arg_dict[symbol]) + except KeyError: + new_args.append(InputArgument(symbol)) + arg_list = new_args + + return Routine(name, arg_list, return_vals, local_vars, global_vars) + + def _get_header(self): + """Writes a common header for the generated files.""" + code_lines = [] + tmp = header_comment % {"version": sympy_version, + "project": self.project} + for line in tmp.splitlines(): + if line == '': + code_lines.append("#\n") + else: + code_lines.append("# %s\n" % line) + return code_lines + + def _preprocessor_statements(self, prefix): + return [] + + def _get_routine_opening(self, routine): + """Returns the opening statements of the routine.""" + code_list = [] + code_list.append("function ") + + # Inputs + args = [] + for arg in routine.arguments: + if isinstance(arg, OutputArgument): + raise CodeGenError("Julia: invalid argument of type %s" % + str(type(arg))) + if isinstance(arg, (InputArgument, InOutArgument)): + args.append("%s" % self._get_symbol(arg.name)) + args = ", ".join(args) + code_list.append("%s(%s)\n" % (routine.name, args)) + code_list = [ "".join(code_list) ] + + return code_list + + def _declare_arguments(self, routine): + return [] + + def _declare_globals(self, routine): + return [] + + def _declare_locals(self, routine): + return [] + + def _get_routine_ending(self, routine): + outs = [] + for result in routine.results: + if isinstance(result, Result): + # Note: name not result_var; want `y` not `y[i]` for Indexed + s = self._get_symbol(result.name) + else: + raise CodeGenError("unexpected object in Routine results") + outs.append(s) + return ["return " + ", ".join(outs) + "\nend\n"] + + def _call_printer(self, routine): + declarations = [] + code_lines = [] + for result in routine.results: + if isinstance(result, Result): + assign_to = result.result_var + else: + raise CodeGenError("unexpected object in Routine results") + + constants, not_supported, jl_expr = self._printer_method_with_settings( + 'doprint', {"human": False, "strict": False}, result.expr, assign_to=assign_to) + + for obj, v in sorted(constants, key=str): + declarations.append( + "%s = %s\n" % (obj, v)) + for obj in sorted(not_supported, key=str): + if isinstance(obj, Function): + name = obj.func + else: + name = obj + declarations.append( + "# unsupported: %s\n" % (name)) + code_lines.append("%s\n" % (jl_expr)) + return declarations + code_lines + + def _indent_code(self, codelines): + # Note that indenting seems to happen twice, first + # statement-by-statement by JuliaPrinter then again here. + p = JuliaCodePrinter({'human': False, "strict": False}) + return p.indent_code(codelines) + + def dump_jl(self, routines, f, prefix, header=True, empty=True): + self.dump_code(routines, f, prefix, header, empty) + + dump_jl.extension = code_extension # type: ignore + dump_jl.__doc__ = CodeGen.dump_code.__doc__ + + # This list of dump functions is used by CodeGen.write to know which dump + # functions it has to call. + dump_fns = [dump_jl] + + +class OctaveCodeGen(CodeGen): + """Generator for Octave code. + + The .write() method inherited from CodeGen will output a code file + .m. + + Octave .m files usually contain one function. That function name should + match the filename (``prefix``). If you pass multiple ``name_expr`` pairs, + the latter ones are presumed to be private functions accessed by the + primary function. + + You should only pass inputs to ``argument_sequence``: outputs are ordered + according to their order in ``name_expr``. + + """ + + code_extension = "m" + + def __init__(self, project='project', printer=None): + super().__init__(project) + self.printer = printer or OctaveCodePrinter() + + def routine(self, name, expr, argument_sequence, global_vars): + """Specialized Routine creation for Octave.""" + + # FIXME: this is probably general enough for other high-level + # languages, perhaps its the C/Fortran one that is specialized! + + if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)): + if not expr: + raise ValueError("No expression given") + expressions = Tuple(*expr) + else: + expressions = Tuple(expr) + + # local variables + local_vars = {i.label for i in expressions.atoms(Idx)} + + # global variables + global_vars = set() if global_vars is None else set(global_vars) + + # symbols that should be arguments + old_symbols = expressions.free_symbols - local_vars - global_vars + symbols = set() + for s in old_symbols: + if isinstance(s, Idx): + symbols.update(s.args[1].free_symbols) + elif not isinstance(s, Indexed): + symbols.add(s) + + # Octave supports multiple return values + return_vals = [] + for (i, expr) in enumerate(expressions): + if isinstance(expr, Equality): + out_arg = expr.lhs + expr = expr.rhs + symbol = out_arg + if isinstance(out_arg, Indexed): + symbol = out_arg.base.label + if not isinstance(out_arg, (Indexed, Symbol, MatrixSymbol)): + raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol " + "can define output arguments.") + + return_vals.append(Result(expr, name=symbol, result_var=out_arg)) + if not expr.has(symbol): + # this is a pure output: remove from the symbols list, so + # it doesn't become an input. + symbols.remove(symbol) + + else: + # we have no name for this output + return_vals.append(Result(expr, name='out%d' % (i+1))) + + # setup input argument list + arg_list = [] + array_symbols = {} + for array in expressions.atoms(Indexed): + array_symbols[array.base.label] = array + for array in expressions.atoms(MatrixSymbol): + array_symbols[array] = array + + for symbol in sorted(symbols, key=str): + arg_list.append(InputArgument(symbol)) + + if argument_sequence is not None: + # if the user has supplied IndexedBase instances, we'll accept that + new_sequence = [] + for arg in argument_sequence: + if isinstance(arg, IndexedBase): + new_sequence.append(arg.label) + else: + new_sequence.append(arg) + argument_sequence = new_sequence + + missing = [x for x in arg_list if x.name not in argument_sequence] + if missing: + msg = "Argument list didn't specify: {0} " + msg = msg.format(", ".join([str(m.name) for m in missing])) + raise CodeGenArgumentListError(msg, missing) + + # create redundant arguments to produce the requested sequence + name_arg_dict = {x.name: x for x in arg_list} + new_args = [] + for symbol in argument_sequence: + try: + new_args.append(name_arg_dict[symbol]) + except KeyError: + new_args.append(InputArgument(symbol)) + arg_list = new_args + + return Routine(name, arg_list, return_vals, local_vars, global_vars) + + def _get_header(self): + """Writes a common header for the generated files.""" + code_lines = [] + tmp = header_comment % {"version": sympy_version, + "project": self.project} + for line in tmp.splitlines(): + if line == '': + code_lines.append("%\n") + else: + code_lines.append("%% %s\n" % line) + return code_lines + + def _preprocessor_statements(self, prefix): + return [] + + def _get_routine_opening(self, routine): + """Returns the opening statements of the routine.""" + code_list = [] + code_list.append("function ") + + # Outputs + outs = [] + for result in routine.results: + if isinstance(result, Result): + # Note: name not result_var; want `y` not `y(i)` for Indexed + s = self._get_symbol(result.name) + else: + raise CodeGenError("unexpected object in Routine results") + outs.append(s) + if len(outs) > 1: + code_list.append("[" + (", ".join(outs)) + "]") + else: + code_list.append("".join(outs)) + code_list.append(" = ") + + # Inputs + args = [] + for arg in routine.arguments: + if isinstance(arg, (OutputArgument, InOutArgument)): + raise CodeGenError("Octave: invalid argument of type %s" % + str(type(arg))) + if isinstance(arg, InputArgument): + args.append("%s" % self._get_symbol(arg.name)) + args = ", ".join(args) + code_list.append("%s(%s)\n" % (routine.name, args)) + code_list = [ "".join(code_list) ] + + return code_list + + def _declare_arguments(self, routine): + return [] + + def _declare_globals(self, routine): + if not routine.global_vars: + return [] + s = " ".join(sorted([self._get_symbol(g) for g in routine.global_vars])) + return ["global " + s + "\n"] + + def _declare_locals(self, routine): + return [] + + def _get_routine_ending(self, routine): + return ["end\n"] + + def _call_printer(self, routine): + declarations = [] + code_lines = [] + for result in routine.results: + if isinstance(result, Result): + assign_to = result.result_var + else: + raise CodeGenError("unexpected object in Routine results") + + constants, not_supported, oct_expr = self._printer_method_with_settings( + 'doprint', {"human": False, "strict": False}, result.expr, assign_to=assign_to) + + for obj, v in sorted(constants, key=str): + declarations.append( + " %s = %s; %% constant\n" % (obj, v)) + for obj in sorted(not_supported, key=str): + if isinstance(obj, Function): + name = obj.func + else: + name = obj + declarations.append( + " %% unsupported: %s\n" % (name)) + code_lines.append("%s\n" % (oct_expr)) + return declarations + code_lines + + def _indent_code(self, codelines): + return self._printer_method_with_settings( + 'indent_code', {"human": False, "strict": False}, codelines) + + def dump_m(self, routines, f, prefix, header=True, empty=True, inline=True): + # Note used to call self.dump_code() but we need more control for header + + code_lines = self._preprocessor_statements(prefix) + + for i, routine in enumerate(routines): + if i > 0: + if empty: + code_lines.append("\n") + code_lines.extend(self._get_routine_opening(routine)) + if i == 0: + if routine.name != prefix: + raise ValueError('Octave function name should match prefix') + if header: + code_lines.append("%" + prefix.upper() + + " Autogenerated by SymPy\n") + code_lines.append(''.join(self._get_header())) + code_lines.extend(self._declare_arguments(routine)) + code_lines.extend(self._declare_globals(routine)) + code_lines.extend(self._declare_locals(routine)) + if empty: + code_lines.append("\n") + code_lines.extend(self._call_printer(routine)) + if empty: + code_lines.append("\n") + code_lines.extend(self._get_routine_ending(routine)) + + code_lines = self._indent_code(''.join(code_lines)) + + if code_lines: + f.write(code_lines) + + dump_m.extension = code_extension # type: ignore + dump_m.__doc__ = CodeGen.dump_code.__doc__ + + # This list of dump functions is used by CodeGen.write to know which dump + # functions it has to call. + dump_fns = [dump_m] + +class RustCodeGen(CodeGen): + """Generator for Rust code. + + The .write() method inherited from CodeGen will output a code file + .rs + + """ + + code_extension = "rs" + + def __init__(self, project="project", printer=None): + super().__init__(project=project) + self.printer = printer or RustCodePrinter() + + def routine(self, name, expr, argument_sequence, global_vars): + """Specialized Routine creation for Rust.""" + + if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)): + if not expr: + raise ValueError("No expression given") + expressions = Tuple(*expr) + else: + expressions = Tuple(expr) + + # local variables + local_vars = {i.label for i in expressions.atoms(Idx)} + + # global variables + global_vars = set() if global_vars is None else set(global_vars) + + # symbols that should be arguments + symbols = expressions.free_symbols - local_vars - global_vars - expressions.atoms(Indexed) + + # Rust supports multiple return values + return_vals = [] + output_args = [] + for (i, expr) in enumerate(expressions): + if isinstance(expr, Equality): + out_arg = expr.lhs + expr = expr.rhs + symbol = out_arg + if isinstance(out_arg, Indexed): + dims = tuple([ (S.One, dim) for dim in out_arg.shape]) + symbol = out_arg.base.label + output_args.append(InOutArgument(symbol, out_arg, expr, dimensions=dims)) + if not isinstance(out_arg, (Indexed, Symbol, MatrixSymbol)): + raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol " + "can define output arguments.") + + return_vals.append(Result(expr, name=symbol, result_var=out_arg)) + if not expr.has(symbol): + # this is a pure output: remove from the symbols list, so + # it doesn't become an input. + symbols.remove(symbol) + + else: + # we have no name for this output + return_vals.append(Result(expr, name='out%d' % (i+1))) + + # setup input argument list + output_args.sort(key=lambda x: str(x.name)) + arg_list = list(output_args) + array_symbols = {} + for array in expressions.atoms(Indexed): + array_symbols[array.base.label] = array + for array in expressions.atoms(MatrixSymbol): + array_symbols[array] = array + + for symbol in sorted(symbols, key=str): + arg_list.append(InputArgument(symbol)) + + if argument_sequence is not None: + # if the user has supplied IndexedBase instances, we'll accept that + new_sequence = [] + for arg in argument_sequence: + if isinstance(arg, IndexedBase): + new_sequence.append(arg.label) + else: + new_sequence.append(arg) + argument_sequence = new_sequence + + missing = [x for x in arg_list if x.name not in argument_sequence] + if missing: + msg = "Argument list didn't specify: {0} " + msg = msg.format(", ".join([str(m.name) for m in missing])) + raise CodeGenArgumentListError(msg, missing) + + # create redundant arguments to produce the requested sequence + name_arg_dict = {x.name: x for x in arg_list} + new_args = [] + for symbol in argument_sequence: + try: + new_args.append(name_arg_dict[symbol]) + except KeyError: + new_args.append(InputArgument(symbol)) + arg_list = new_args + + return Routine(name, arg_list, return_vals, local_vars, global_vars) + + + def _get_header(self): + """Writes a common header for the generated files.""" + code_lines = [] + code_lines.append("/*\n") + tmp = header_comment % {"version": sympy_version, + "project": self.project} + for line in tmp.splitlines(): + code_lines.append((" *%s" % line.center(76)).rstrip() + "\n") + code_lines.append(" */\n") + return code_lines + + def get_prototype(self, routine): + """Returns a string for the function prototype of the routine. + + If the routine has multiple result objects, an CodeGenError is + raised. + + See: https://en.wikipedia.org/wiki/Function_prototype + + """ + results = [i.get_datatype('Rust') for i in routine.results] + + if len(results) == 1: + rstype = " -> " + results[0] + elif len(routine.results) > 1: + rstype = " -> (" + ", ".join(results) + ")" + else: + rstype = "" + + type_args = [] + for arg in routine.arguments: + name = self.printer.doprint(arg.name) + if arg.dimensions or isinstance(arg, ResultBase): + type_args.append(("*%s" % name, arg.get_datatype('Rust'))) + else: + type_args.append((name, arg.get_datatype('Rust'))) + arguments = ", ".join([ "%s: %s" % t for t in type_args]) + return "fn %s(%s)%s" % (routine.name, arguments, rstype) + + def _preprocessor_statements(self, prefix): + code_lines = [] + # code_lines.append("use std::f64::consts::*;\n") + return code_lines + + def _get_routine_opening(self, routine): + prototype = self.get_prototype(routine) + return ["%s {\n" % prototype] + + def _declare_arguments(self, routine): + # arguments are declared in prototype + return [] + + def _declare_globals(self, routine): + # global variables are not explicitly declared within C functions + return [] + + def _declare_locals(self, routine): + # loop variables are declared in loop statement + return [] + + def _call_printer(self, routine): + + code_lines = [] + declarations = [] + returns = [] + + # Compose a list of symbols to be dereferenced in the function + # body. These are the arguments that were passed by a reference + # pointer, excluding arrays. + dereference = [] + for arg in routine.arguments: + if isinstance(arg, ResultBase) and not arg.dimensions: + dereference.append(arg.name) + + for result in routine.results: + if isinstance(result, Result): + assign_to = result.result_var + returns.append(str(result.result_var)) + else: + raise CodeGenError("unexpected object in Routine results") + + constants, not_supported, rs_expr = self._printer_method_with_settings( + 'doprint', {"human": False, "strict": False}, result.expr, assign_to=assign_to) + + for name, value in sorted(constants, key=str): + declarations.append("const %s: f64 = %s;\n" % (name, value)) + + for obj in sorted(not_supported, key=str): + if isinstance(obj, Function): + name = obj.func + else: + name = obj + declarations.append("// unsupported: %s\n" % (name)) + + code_lines.append("let %s\n" % rs_expr) + + if len(returns) > 1: + returns = ['(' + ', '.join(returns) + ')'] + + returns.append('\n') + + return declarations + code_lines + returns + + def _get_routine_ending(self, routine): + return ["}\n"] + + def dump_rs(self, routines, f, prefix, header=True, empty=True): + self.dump_code(routines, f, prefix, header, empty) + + dump_rs.extension = code_extension # type: ignore + dump_rs.__doc__ = CodeGen.dump_code.__doc__ + + # This list of dump functions is used by CodeGen.write to know which dump + # functions it has to call. + dump_fns = [dump_rs] + + + + +def get_code_generator(language, project=None, standard=None, printer = None): + if language == 'C': + if standard is None: + pass + elif standard.lower() == 'c89': + language = 'C89' + elif standard.lower() == 'c99': + language = 'C99' + CodeGenClass = {"C": CCodeGen, "C89": C89CodeGen, "C99": C99CodeGen, + "F95": FCodeGen, "JULIA": JuliaCodeGen, + "OCTAVE": OctaveCodeGen, + "RUST": RustCodeGen}.get(language.upper()) + if CodeGenClass is None: + raise ValueError("Language '%s' is not supported." % language) + return CodeGenClass(project, printer) + + +# +# Friendly functions +# + + +def codegen(name_expr, language=None, prefix=None, project="project", + to_files=False, header=True, empty=True, argument_sequence=None, + global_vars=None, standard=None, code_gen=None, printer=None): + """Generate source code for expressions in a given language. + + Parameters + ========== + + name_expr : tuple, or list of tuples + A single (name, expression) tuple or a list of (name, expression) + tuples. Each tuple corresponds to a routine. If the expression is + an equality (an instance of class Equality) the left hand side is + considered an output argument. If expression is an iterable, then + the routine will have multiple outputs. + + language : string, + A string that indicates the source code language. This is case + insensitive. Currently, 'C', 'F95' and 'Octave' are supported. + 'Octave' generates code compatible with both Octave and Matlab. + + prefix : string, optional + A prefix for the names of the files that contain the source code. + Language-dependent suffixes will be appended. If omitted, the name + of the first name_expr tuple is used. + + project : string, optional + A project name, used for making unique preprocessor instructions. + [default: "project"] + + to_files : bool, optional + When True, the code will be written to one or more files with the + given prefix, otherwise strings with the names and contents of + these files are returned. [default: False] + + header : bool, optional + When True, a header is written on top of each source file. + [default: True] + + empty : bool, optional + When True, empty lines are used to structure the code. + [default: True] + + argument_sequence : iterable, optional + Sequence of arguments for the routine in a preferred order. A + CodeGenError is raised if required arguments are missing. + Redundant arguments are used without warning. If omitted, + arguments will be ordered alphabetically, but with all input + arguments first, and then output or in-out arguments. + + global_vars : iterable, optional + Sequence of global variables used by the routine. Variables + listed here will not show up as function arguments. + + standard : string, optional + + code_gen : CodeGen instance, optional + An instance of a CodeGen subclass. Overrides ``language``. + + printer : Printer instance, optional + An instance of a Printer subclass. + + Examples + ======== + + >>> from sympy.utilities.codegen import codegen + >>> from sympy.abc import x, y, z + >>> [(c_name, c_code), (h_name, c_header)] = codegen( + ... ("f", x+y*z), "C89", "test", header=False, empty=False) + >>> print(c_name) + test.c + >>> print(c_code) + #include "test.h" + #include + double f(double x, double y, double z) { + double f_result; + f_result = x + y*z; + return f_result; + } + + >>> print(h_name) + test.h + >>> print(c_header) + #ifndef PROJECT__TEST__H + #define PROJECT__TEST__H + double f(double x, double y, double z); + #endif + + + Another example using Equality objects to give named outputs. Here the + filename (prefix) is taken from the first (name, expr) pair. + + >>> from sympy.abc import f, g + >>> from sympy import Eq + >>> [(c_name, c_code), (h_name, c_header)] = codegen( + ... [("myfcn", x + y), ("fcn2", [Eq(f, 2*x), Eq(g, y)])], + ... "C99", header=False, empty=False) + >>> print(c_name) + myfcn.c + >>> print(c_code) + #include "myfcn.h" + #include + double myfcn(double x, double y) { + double myfcn_result; + myfcn_result = x + y; + return myfcn_result; + } + void fcn2(double x, double y, double *f, double *g) { + (*f) = 2*x; + (*g) = y; + } + + + If the generated function(s) will be part of a larger project where various + global variables have been defined, the 'global_vars' option can be used + to remove the specified variables from the function signature + + >>> from sympy.utilities.codegen import codegen + >>> from sympy.abc import x, y, z + >>> [(f_name, f_code), header] = codegen( + ... ("f", x+y*z), "F95", header=False, empty=False, + ... argument_sequence=(x, y), global_vars=(z,)) + >>> print(f_code) + REAL*8 function f(x, y) + implicit none + REAL*8, intent(in) :: x + REAL*8, intent(in) :: y + f = x + y*z + end function + + + """ + + # Initialize the code generator. + if language is None: + if code_gen is None: + raise ValueError("Need either language or code_gen") + else: + if code_gen is not None: + raise ValueError("You cannot specify both language and code_gen.") + code_gen = get_code_generator(language, project, standard, printer) + + if isinstance(name_expr[0], str): + # single tuple is given, turn it into a singleton list with a tuple. + name_expr = [name_expr] + + if prefix is None: + prefix = name_expr[0][0] + + # Construct Routines appropriate for this code_gen from (name, expr) pairs. + routines = [] + for name, expr in name_expr: + routines.append(code_gen.routine(name, expr, argument_sequence, + global_vars)) + + # Write the code. + return code_gen.write(routines, prefix, to_files, header, empty) + + +def make_routine(name, expr, argument_sequence=None, + global_vars=None, language="F95"): + """A factory that makes an appropriate Routine from an expression. + + Parameters + ========== + + name : string + The name of this routine in the generated code. + + expr : expression or list/tuple of expressions + A SymPy expression that the Routine instance will represent. If + given a list or tuple of expressions, the routine will be + considered to have multiple return values and/or output arguments. + + argument_sequence : list or tuple, optional + List arguments for the routine in a preferred order. If omitted, + the results are language dependent, for example, alphabetical order + or in the same order as the given expressions. + + global_vars : iterable, optional + Sequence of global variables used by the routine. Variables + listed here will not show up as function arguments. + + language : string, optional + Specify a target language. The Routine itself should be + language-agnostic but the precise way one is created, error + checking, etc depend on the language. [default: "F95"]. + + Notes + ===== + + A decision about whether to use output arguments or return values is made + depending on both the language and the particular mathematical expressions. + For an expression of type Equality, the left hand side is typically made + into an OutputArgument (or perhaps an InOutArgument if appropriate). + Otherwise, typically, the calculated expression is made a return values of + the routine. + + Examples + ======== + + >>> from sympy.utilities.codegen import make_routine + >>> from sympy.abc import x, y, f, g + >>> from sympy import Eq + >>> r = make_routine('test', [Eq(f, 2*x), Eq(g, x + y)]) + >>> [arg.result_var for arg in r.results] + [] + >>> [arg.name for arg in r.arguments] + [x, y, f, g] + >>> [arg.name for arg in r.result_variables] + [f, g] + >>> r.local_vars + set() + + Another more complicated example with a mixture of specified and + automatically-assigned names. Also has Matrix output. + + >>> from sympy import Matrix + >>> r = make_routine('fcn', [x*y, Eq(f, 1), Eq(g, x + g), Matrix([[x, 2]])]) + >>> [arg.result_var for arg in r.results] # doctest: +SKIP + [result_5397460570204848505] + >>> [arg.expr for arg in r.results] + [x*y] + >>> [arg.name for arg in r.arguments] # doctest: +SKIP + [x, y, f, g, out_8598435338387848786] + + We can examine the various arguments more closely: + + >>> from sympy.utilities.codegen import (InputArgument, OutputArgument, + ... InOutArgument) + >>> [a.name for a in r.arguments if isinstance(a, InputArgument)] + [x, y] + + >>> [a.name for a in r.arguments if isinstance(a, OutputArgument)] # doctest: +SKIP + [f, out_8598435338387848786] + >>> [a.expr for a in r.arguments if isinstance(a, OutputArgument)] + [1, Matrix([[x, 2]])] + + >>> [a.name for a in r.arguments if isinstance(a, InOutArgument)] + [g] + >>> [a.expr for a in r.arguments if isinstance(a, InOutArgument)] + [g + x] + + """ + + # initialize a new code generator + code_gen = get_code_generator(language) + + return code_gen.routine(name, expr, argument_sequence, global_vars) diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/decorator.py b/.venv/lib/python3.13/site-packages/sympy/utilities/decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..82636c35e9f235a1d23aa7df5182418db3ef6b4f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/decorator.py @@ -0,0 +1,339 @@ +"""Useful utility decorators. """ + +from typing import TypeVar +import sys +import types +import inspect +from functools import wraps, update_wrapper + +from sympy.utilities.exceptions import sympy_deprecation_warning + + +T = TypeVar('T') +"""A generic type""" + + +def threaded_factory(func, use_add): + """A factory for ``threaded`` decorators. """ + from sympy.core import sympify + from sympy.matrices import MatrixBase + from sympy.utilities.iterables import iterable + + @wraps(func) + def threaded_func(expr, *args, **kwargs): + if isinstance(expr, MatrixBase): + return expr.applyfunc(lambda f: func(f, *args, **kwargs)) + elif iterable(expr): + try: + return expr.__class__([func(f, *args, **kwargs) for f in expr]) + except TypeError: + return expr + else: + expr = sympify(expr) + + if use_add and expr.is_Add: + return expr.__class__(*[ func(f, *args, **kwargs) for f in expr.args ]) + elif expr.is_Relational: + return expr.__class__(func(expr.lhs, *args, **kwargs), + func(expr.rhs, *args, **kwargs)) + else: + return func(expr, *args, **kwargs) + + return threaded_func + + +def threaded(func): + """Apply ``func`` to sub--elements of an object, including :class:`~.Add`. + + This decorator is intended to make it uniformly possible to apply a + function to all elements of composite objects, e.g. matrices, lists, tuples + and other iterable containers, or just expressions. + + This version of :func:`threaded` decorator allows threading over + elements of :class:`~.Add` class. If this behavior is not desirable + use :func:`xthreaded` decorator. + + Functions using this decorator must have the following signature:: + + @threaded + def function(expr, *args, **kwargs): + + """ + return threaded_factory(func, True) + + +def xthreaded(func): + """Apply ``func`` to sub--elements of an object, excluding :class:`~.Add`. + + This decorator is intended to make it uniformly possible to apply a + function to all elements of composite objects, e.g. matrices, lists, tuples + and other iterable containers, or just expressions. + + This version of :func:`threaded` decorator disallows threading over + elements of :class:`~.Add` class. If this behavior is not desirable + use :func:`threaded` decorator. + + Functions using this decorator must have the following signature:: + + @xthreaded + def function(expr, *args, **kwargs): + + """ + return threaded_factory(func, False) + + +def conserve_mpmath_dps(func): + """After the function finishes, resets the value of ``mpmath.mp.dps`` to + the value it had before the function was run.""" + import mpmath + + def func_wrapper(*args, **kwargs): + dps = mpmath.mp.dps + try: + return func(*args, **kwargs) + finally: + mpmath.mp.dps = dps + + func_wrapper = update_wrapper(func_wrapper, func) + return func_wrapper + + +class no_attrs_in_subclass: + """Don't 'inherit' certain attributes from a base class + + >>> from sympy.utilities.decorator import no_attrs_in_subclass + + >>> class A(object): + ... x = 'test' + + >>> A.x = no_attrs_in_subclass(A, A.x) + + >>> class B(A): + ... pass + + >>> hasattr(A, 'x') + True + >>> hasattr(B, 'x') + False + + """ + def __init__(self, cls, f): + self.cls = cls + self.f = f + + def __get__(self, instance, owner=None): + if owner == self.cls: + if hasattr(self.f, '__get__'): + return self.f.__get__(instance, owner) + return self.f + raise AttributeError + + +def doctest_depends_on(exe=None, modules=None, disable_viewers=None, + python_version=None, ground_types=None): + """ + Adds metadata about the dependencies which need to be met for doctesting + the docstrings of the decorated objects. + + ``exe`` should be a list of executables + + ``modules`` should be a list of modules + + ``disable_viewers`` should be a list of viewers for :func:`~sympy.printing.preview.preview` to disable + + ``python_version`` should be the minimum Python version required, as a tuple + (like ``(3, 0)``) + """ + dependencies = {} + if exe is not None: + dependencies['executables'] = exe + if modules is not None: + dependencies['modules'] = modules + if disable_viewers is not None: + dependencies['disable_viewers'] = disable_viewers + if python_version is not None: + dependencies['python_version'] = python_version + if ground_types is not None: + dependencies['ground_types'] = ground_types + + def skiptests(): + from sympy.testing.runtests import DependencyError, SymPyDocTests, PyTestReporter # lazy import + r = PyTestReporter() + t = SymPyDocTests(r, None) + try: + t._check_dependencies(**dependencies) + except DependencyError: + return True # Skip doctests + else: + return False # Run doctests + + def depends_on_deco(fn): + fn._doctest_depends_on = dependencies + fn.__doctest_skip__ = skiptests + + if inspect.isclass(fn): + fn._doctest_depdends_on = no_attrs_in_subclass( + fn, fn._doctest_depends_on) + fn.__doctest_skip__ = no_attrs_in_subclass( + fn, fn.__doctest_skip__) + return fn + + return depends_on_deco + + +def public(obj: T) -> T: + """ + Append ``obj``'s name to global ``__all__`` variable (call site). + + By using this decorator on functions or classes you achieve the same goal + as by filling ``__all__`` variables manually, you just do not have to repeat + yourself (object's name). You also know if object is public at definition + site, not at some random location (where ``__all__`` was set). + + Note that in multiple decorator setup (in almost all cases) ``@public`` + decorator must be applied before any other decorators, because it relies + on the pointer to object's global namespace. If you apply other decorators + first, ``@public`` may end up modifying the wrong namespace. + + Examples + ======== + + >>> from sympy.utilities.decorator import public + + >>> __all__ # noqa: F821 + Traceback (most recent call last): + ... + NameError: name '__all__' is not defined + + >>> @public + ... def some_function(): + ... pass + + >>> __all__ # noqa: F821 + ['some_function'] + + """ + if isinstance(obj, types.FunctionType): + ns = obj.__globals__ + name = obj.__name__ + elif isinstance(obj, (type(type), type)): + ns = sys.modules[obj.__module__].__dict__ + name = obj.__name__ + else: + raise TypeError("expected a function or a class, got %s" % obj) + + if "__all__" not in ns: + ns["__all__"] = [name] + else: + ns["__all__"].append(name) + + return obj + + +def memoize_property(propfunc): + """Property decorator that caches the value of potentially expensive + ``propfunc`` after the first evaluation. The cached value is stored in + the corresponding property name with an attached underscore.""" + attrname = '_' + propfunc.__name__ + sentinel = object() + + @wraps(propfunc) + def accessor(self): + val = getattr(self, attrname, sentinel) + if val is sentinel: + val = propfunc(self) + setattr(self, attrname, val) + return val + + return property(accessor) + + +def deprecated(message, *, deprecated_since_version, + active_deprecations_target, stacklevel=3): + ''' + Mark a function as deprecated. + + This decorator should be used if an entire function or class is + deprecated. If only a certain functionality is deprecated, you should use + :func:`~.warns_deprecated_sympy` directly. This decorator is just a + convenience. There is no functional difference between using this + decorator and calling ``warns_deprecated_sympy()`` at the top of the + function. + + The decorator takes the same arguments as + :func:`~.warns_deprecated_sympy`. See its + documentation for details on what the keywords to this decorator do. + + See the :ref:`deprecation-policy` document for details on when and how + things should be deprecated in SymPy. + + Examples + ======== + + >>> from sympy.utilities.decorator import deprecated + >>> from sympy import simplify + >>> @deprecated("""\ + ... The simplify_this(expr) function is deprecated. Use simplify(expr) + ... instead.""", deprecated_since_version="1.1", + ... active_deprecations_target='simplify-this-deprecation') + ... def simplify_this(expr): + ... """ + ... Simplify ``expr``. + ... + ... .. deprecated:: 1.1 + ... + ... The ``simplify_this`` function is deprecated. Use :func:`simplify` + ... instead. See its documentation for more information. See + ... :ref:`simplify-this-deprecation` for details. + ... + ... """ + ... return simplify(expr) + >>> from sympy.abc import x + >>> simplify_this(x*(x + 1) - x**2) # doctest: +SKIP + :1: SymPyDeprecationWarning: + + The simplify_this(expr) function is deprecated. Use simplify(expr) + instead. + + See https://docs.sympy.org/latest/explanation/active-deprecations.html#simplify-this-deprecation + for details. + + This has been deprecated since SymPy version 1.1. It + will be removed in a future version of SymPy. + + simplify_this(x) + x + + See Also + ======== + sympy.utilities.exceptions.SymPyDeprecationWarning + sympy.utilities.exceptions.sympy_deprecation_warning + sympy.utilities.exceptions.ignore_warnings + sympy.testing.pytest.warns_deprecated_sympy + + ''' + decorator_kwargs = {"deprecated_since_version": deprecated_since_version, + "active_deprecations_target": active_deprecations_target} + def deprecated_decorator(wrapped): + if hasattr(wrapped, '__mro__'): # wrapped is actually a class + class wrapper(wrapped): + __doc__ = wrapped.__doc__ + __module__ = wrapped.__module__ + _sympy_deprecated_func = wrapped + if '__new__' in wrapped.__dict__: + def __new__(cls, *args, **kwargs): + sympy_deprecation_warning(message, **decorator_kwargs, stacklevel=stacklevel) + return super().__new__(cls, *args, **kwargs) + else: + def __init__(self, *args, **kwargs): + sympy_deprecation_warning(message, **decorator_kwargs, stacklevel=stacklevel) + super().__init__(*args, **kwargs) + wrapper.__name__ = wrapped.__name__ + else: + @wraps(wrapped) + def wrapper(*args, **kwargs): + sympy_deprecation_warning(message, **decorator_kwargs, stacklevel=stacklevel) + return wrapped(*args, **kwargs) + wrapper._sympy_deprecated_func = wrapped + return wrapper + return deprecated_decorator diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/enumerative.py b/.venv/lib/python3.13/site-packages/sympy/utilities/enumerative.py new file mode 100644 index 0000000000000000000000000000000000000000..dcdabf064ad6291282b5905e7aed0ba9eb412d1a --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/enumerative.py @@ -0,0 +1,1155 @@ +""" +Algorithms and classes to support enumerative combinatorics. + +Currently just multiset partitions, but more could be added. + +Terminology (following Knuth, algorithm 7.1.2.5M TAOCP) +*multiset* aaabbcccc has a *partition* aaabc | bccc + +The submultisets, aaabc and bccc of the partition are called +*parts*, or sometimes *vectors*. (Knuth notes that multiset +partitions can be thought of as partitions of vectors of integers, +where the ith element of the vector gives the multiplicity of +element i.) + +The values a, b and c are *components* of the multiset. These +correspond to elements of a set, but in a multiset can be present +with a multiplicity greater than 1. + +The algorithm deserves some explanation. + +Think of the part aaabc from the multiset above. If we impose an +ordering on the components of the multiset, we can represent a part +with a vector, in which the value of the first element of the vector +corresponds to the multiplicity of the first component in that +part. Thus, aaabc can be represented by the vector [3, 1, 1]. We +can also define an ordering on parts, based on the lexicographic +ordering of the vector (leftmost vector element, i.e., the element +with the smallest component number, is the most significant), so +that [3, 1, 1] > [3, 1, 0] and [3, 1, 1] > [2, 1, 4]. The ordering +on parts can be extended to an ordering on partitions: First, sort +the parts in each partition, left-to-right in decreasing order. Then +partition A is greater than partition B if A's leftmost/greatest +part is greater than B's leftmost part. If the leftmost parts are +equal, compare the second parts, and so on. + +In this ordering, the greatest partition of a given multiset has only +one part. The least partition is the one in which the components +are spread out, one per part. + +The enumeration algorithms in this file yield the partitions of the +argument multiset in decreasing order. The main data structure is a +stack of parts, corresponding to the current partition. An +important invariant is that the parts on the stack are themselves in +decreasing order. This data structure is decremented to find the +next smaller partition. Most often, decrementing the partition will +only involve adjustments to the smallest parts at the top of the +stack, much as adjacent integers *usually* differ only in their last +few digits. + +Knuth's algorithm uses two main operations on parts: + +Decrement - change the part so that it is smaller in the + (vector) lexicographic order, but reduced by the smallest amount possible. + For example, if the multiset has vector [5, + 3, 1], and the bottom/greatest part is [4, 2, 1], this part would + decrement to [4, 2, 0], while [4, 0, 0] would decrement to [3, 3, + 1]. A singleton part is never decremented -- [1, 0, 0] is not + decremented to [0, 3, 1]. Instead, the decrement operator needs + to fail for this case. In Knuth's pseudocode, the decrement + operator is step m5. + +Spread unallocated multiplicity - Once a part has been decremented, + it cannot be the rightmost part in the partition. There is some + multiplicity that has not been allocated, and new parts must be + created above it in the stack to use up this multiplicity. To + maintain the invariant that the parts on the stack are in + decreasing order, these new parts must be less than or equal to + the decremented part. + For example, if the multiset is [5, 3, 1], and its most + significant part has just been decremented to [5, 3, 0], the + spread operation will add a new part so that the stack becomes + [[5, 3, 0], [0, 0, 1]]. If the most significant part (for the + same multiset) has been decremented to [2, 0, 0] the stack becomes + [[2, 0, 0], [2, 0, 0], [1, 3, 1]]. In the pseudocode, the spread + operation for one part is step m2. The complete spread operation + is a loop of steps m2 and m3. + +In order to facilitate the spread operation, Knuth stores, for each +component of each part, not just the multiplicity of that component +in the part, but also the total multiplicity available for this +component in this part or any lesser part above it on the stack. + +One added twist is that Knuth does not represent the part vectors as +arrays. Instead, he uses a sparse representation, in which a +component of a part is represented as a component number (c), plus +the multiplicity of the component in that part (v) as well as the +total multiplicity available for that component (u). This saves +time that would be spent skipping over zeros. + +""" + +class PartComponent: + """Internal class used in support of the multiset partitions + enumerators and the associated visitor functions. + + Represents one component of one part of the current partition. + + A stack of these, plus an auxiliary frame array, f, represents a + partition of the multiset. + + Knuth's pseudocode makes c, u, and v separate arrays. + """ + + __slots__ = ('c', 'u', 'v') + + def __init__(self): + self.c = 0 # Component number + self.u = 0 # The as yet unpartitioned amount in component c + # *before* it is allocated by this triple + self.v = 0 # Amount of c component in the current part + # (v<=u). An invariant of the representation is + # that the next higher triple for this component + # (if there is one) will have a value of u-v in + # its u attribute. + + def __repr__(self): + "for debug/algorithm animation purposes" + return 'c:%d u:%d v:%d' % (self.c, self.u, self.v) + + def __eq__(self, other): + """Define value oriented equality, which is useful for testers""" + return (isinstance(other, self.__class__) and + self.c == other.c and + self.u == other.u and + self.v == other.v) + + def __ne__(self, other): + """Defined for consistency with __eq__""" + return not self == other + + +# This function tries to be a faithful implementation of algorithm +# 7.1.2.5M in Volume 4A, Combinatoral Algorithms, Part 1, of The Art +# of Computer Programming, by Donald Knuth. This includes using +# (mostly) the same variable names, etc. This makes for rather +# low-level Python. + +# Changes from Knuth's pseudocode include +# - use PartComponent struct/object instead of 3 arrays +# - make the function a generator +# - map (with some difficulty) the GOTOs to Python control structures. +# - Knuth uses 1-based numbering for components, this code is 0-based +# - renamed variable l to lpart. +# - flag variable x takes on values True/False instead of 1/0 +# +def multiset_partitions_taocp(multiplicities): + """Enumerates partitions of a multiset. + + Parameters + ========== + + multiplicities + list of integer multiplicities of the components of the multiset. + + Yields + ====== + + state + Internal data structure which encodes a particular partition. + This output is then usually processed by a visitor function + which combines the information from this data structure with + the components themselves to produce an actual partition. + + Unless they wish to create their own visitor function, users will + have little need to look inside this data structure. But, for + reference, it is a 3-element list with components: + + f + is a frame array, which is used to divide pstack into parts. + + lpart + points to the base of the topmost part. + + pstack + is an array of PartComponent objects. + + The ``state`` output offers a peek into the internal data + structures of the enumeration function. The client should + treat this as read-only; any modification of the data + structure will cause unpredictable (and almost certainly + incorrect) results. Also, the components of ``state`` are + modified in place at each iteration. Hence, the visitor must + be called at each loop iteration. Accumulating the ``state`` + instances and processing them later will not work. + + Examples + ======== + + >>> from sympy.utilities.enumerative import list_visitor + >>> from sympy.utilities.enumerative import multiset_partitions_taocp + >>> # variables components and multiplicities represent the multiset 'abb' + >>> components = 'ab' + >>> multiplicities = [1, 2] + >>> states = multiset_partitions_taocp(multiplicities) + >>> list(list_visitor(state, components) for state in states) + [[['a', 'b', 'b']], + [['a', 'b'], ['b']], + [['a'], ['b', 'b']], + [['a'], ['b'], ['b']]] + + See Also + ======== + + sympy.utilities.iterables.multiset_partitions: Takes a multiset + as input and directly yields multiset partitions. It + dispatches to a number of functions, including this one, for + implementation. Most users will find it more convenient to + use than multiset_partitions_taocp. + + """ + + # Important variables. + # m is the number of components, i.e., number of distinct elements + m = len(multiplicities) + # n is the cardinality, total number of elements whether or not distinct + n = sum(multiplicities) + + # The main data structure, f segments pstack into parts. See + # list_visitor() for example code indicating how this internal + # state corresponds to a partition. + + # Note: allocation of space for stack is conservative. Knuth's + # exercise 7.2.1.5.68 gives some indication of how to tighten this + # bound, but this is not implemented. + pstack = [PartComponent() for i in range(n * m + 1)] + f = [0] * (n + 1) + + # Step M1 in Knuth (Initialize) + # Initial state - entire multiset in one part. + for j in range(m): + ps = pstack[j] + ps.c = j + ps.u = multiplicities[j] + ps.v = multiplicities[j] + + # Other variables + f[0] = 0 + a = 0 + lpart = 0 + f[1] = m + b = m # in general, current stack frame is from a to b - 1 + + while True: + while True: + # Step M2 (Subtract v from u) + k = b + x = False + for j in range(a, b): + pstack[k].u = pstack[j].u - pstack[j].v + if pstack[k].u == 0: + x = True + elif not x: + pstack[k].c = pstack[j].c + pstack[k].v = min(pstack[j].v, pstack[k].u) + x = pstack[k].u < pstack[j].v + k = k + 1 + else: # x is True + pstack[k].c = pstack[j].c + pstack[k].v = pstack[k].u + k = k + 1 + # Note: x is True iff v has changed + + # Step M3 (Push if nonzero.) + if k > b: + a = b + b = k + lpart = lpart + 1 + f[lpart + 1] = b + # Return to M2 + else: + break # Continue to M4 + + # M4 Visit a partition + state = [f, lpart, pstack] + yield state + + # M5 (Decrease v) + while True: + j = b-1 + while (pstack[j].v == 0): + j = j - 1 + if j == a and pstack[j].v == 1: + # M6 (Backtrack) + if lpart == 0: + return + lpart = lpart - 1 + b = a + a = f[lpart] + # Return to M5 + else: + pstack[j].v = pstack[j].v - 1 + for k in range(j + 1, b): + pstack[k].v = pstack[k].u + break # GOTO M2 + +# --------------- Visitor functions for multiset partitions --------------- +# A visitor takes the partition state generated by +# multiset_partitions_taocp or other enumerator, and produces useful +# output (such as the actual partition). + + +def factoring_visitor(state, primes): + """Use with multiset_partitions_taocp to enumerate the ways a + number can be expressed as a product of factors. For this usage, + the exponents of the prime factors of a number are arguments to + the partition enumerator, while the corresponding prime factors + are input here. + + Examples + ======== + + To enumerate the factorings of a number we can think of the elements of the + partition as being the prime factors and the multiplicities as being their + exponents. + + >>> from sympy.utilities.enumerative import factoring_visitor + >>> from sympy.utilities.enumerative import multiset_partitions_taocp + >>> from sympy import factorint + >>> primes, multiplicities = zip(*factorint(24).items()) + >>> primes + (2, 3) + >>> multiplicities + (3, 1) + >>> states = multiset_partitions_taocp(multiplicities) + >>> list(factoring_visitor(state, primes) for state in states) + [[24], [8, 3], [12, 2], [4, 6], [4, 2, 3], [6, 2, 2], [2, 2, 2, 3]] + """ + f, lpart, pstack = state + factoring = [] + for i in range(lpart + 1): + factor = 1 + for ps in pstack[f[i]: f[i + 1]]: + if ps.v > 0: + factor *= primes[ps.c] ** ps.v + factoring.append(factor) + return factoring + + +def list_visitor(state, components): + """Return a list of lists to represent the partition. + + Examples + ======== + + >>> from sympy.utilities.enumerative import list_visitor + >>> from sympy.utilities.enumerative import multiset_partitions_taocp + >>> states = multiset_partitions_taocp([1, 2, 1]) + >>> s = next(states) + >>> list_visitor(s, 'abc') # for multiset 'a b b c' + [['a', 'b', 'b', 'c']] + >>> s = next(states) + >>> list_visitor(s, [1, 2, 3]) # for multiset '1 2 2 3 + [[1, 2, 2], [3]] + """ + f, lpart, pstack = state + + partition = [] + for i in range(lpart+1): + part = [] + for ps in pstack[f[i]:f[i+1]]: + if ps.v > 0: + part.extend([components[ps.c]] * ps.v) + partition.append(part) + + return partition + + +class MultisetPartitionTraverser(): + """ + Has methods to ``enumerate`` and ``count`` the partitions of a multiset. + + This implements a refactored and extended version of Knuth's algorithm + 7.1.2.5M [AOCP]_." + + The enumeration methods of this class are generators and return + data structures which can be interpreted by the same visitor + functions used for the output of ``multiset_partitions_taocp``. + + Examples + ======== + + >>> from sympy.utilities.enumerative import MultisetPartitionTraverser + >>> m = MultisetPartitionTraverser() + >>> m.count_partitions([4,4,4,2]) + 127750 + >>> m.count_partitions([3,3,3]) + 686 + + See Also + ======== + + multiset_partitions_taocp + sympy.utilities.iterables.multiset_partitions + + References + ========== + + .. [AOCP] Algorithm 7.1.2.5M in Volume 4A, Combinatoral Algorithms, + Part 1, of The Art of Computer Programming, by Donald Knuth. + + .. [Factorisatio] On a Problem of Oppenheim concerning + "Factorisatio Numerorum" E. R. Canfield, Paul Erdos, Carl + Pomerance, JOURNAL OF NUMBER THEORY, Vol. 17, No. 1. August + 1983. See section 7 for a description of an algorithm + similar to Knuth's. + + .. [Yorgey] Generating Multiset Partitions, Brent Yorgey, The + Monad.Reader, Issue 8, September 2007. + + """ + + def __init__(self): + self.debug = False + # TRACING variables. These are useful for gathering + # statistics on the algorithm itself, but have no particular + # benefit to a user of the code. + self.k1 = 0 + self.k2 = 0 + self.p1 = 0 + self.pstack = None + self.f = None + self.lpart = 0 + self.discarded = 0 + # dp_stack is list of lists of (part_key, start_count) pairs + self.dp_stack = [] + + # dp_map is map part_key-> count, where count represents the + # number of multiset which are descendants of a part with this + # key, **or any of its decrements** + + # Thus, when we find a part in the map, we add its count + # value to the running total, cut off the enumeration, and + # backtrack + + if not hasattr(self, 'dp_map'): + self.dp_map = {} + + def db_trace(self, msg): + """Useful for understanding/debugging the algorithms. Not + generally activated in end-user code.""" + if self.debug: + # XXX: animation_visitor is undefined... Clearly this does not + # work and was not tested. Previous code in comments below. + raise RuntimeError + #letters = 'abcdefghijklmnopqrstuvwxyz' + #state = [self.f, self.lpart, self.pstack] + #print("DBG:", msg, + # ["".join(part) for part in list_visitor(state, letters)], + # animation_visitor(state)) + + # + # Helper methods for enumeration + # + def _initialize_enumeration(self, multiplicities): + """Allocates and initializes the partition stack. + + This is called from the enumeration/counting routines, so + there is no need to call it separately.""" + + num_components = len(multiplicities) + # cardinality is the total number of elements, whether or not distinct + cardinality = sum(multiplicities) + + # pstack is the partition stack, which is segmented by + # f into parts. + self.pstack = [PartComponent() for i in + range(num_components * cardinality + 1)] + self.f = [0] * (cardinality + 1) + + # Initial state - entire multiset in one part. + for j in range(num_components): + ps = self.pstack[j] + ps.c = j + ps.u = multiplicities[j] + ps.v = multiplicities[j] + + self.f[0] = 0 + self.f[1] = num_components + self.lpart = 0 + + # The decrement_part() method corresponds to step M5 in Knuth's + # algorithm. This is the base version for enum_all(). Modified + # versions of this method are needed if we want to restrict + # sizes of the partitions produced. + def decrement_part(self, part): + """Decrements part (a subrange of pstack), if possible, returning + True iff the part was successfully decremented. + + If you think of the v values in the part as a multi-digit + integer (least significant digit on the right) this is + basically decrementing that integer, but with the extra + constraint that the leftmost digit cannot be decremented to 0. + + Parameters + ========== + + part + The part, represented as a list of PartComponent objects, + which is to be decremented. + + """ + plen = len(part) + for j in range(plen - 1, -1, -1): + if j == 0 and part[j].v > 1 or j > 0 and part[j].v > 0: + # found val to decrement + part[j].v -= 1 + # Reset trailing parts back to maximum + for k in range(j + 1, plen): + part[k].v = part[k].u + return True + return False + + # Version to allow number of parts to be bounded from above. + # Corresponds to (a modified) step M5. + def decrement_part_small(self, part, ub): + """Decrements part (a subrange of pstack), if possible, returning + True iff the part was successfully decremented. + + Parameters + ========== + + part + part to be decremented (topmost part on the stack) + + ub + the maximum number of parts allowed in a partition + returned by the calling traversal. + + Notes + ===== + + The goal of this modification of the ordinary decrement method + is to fail (meaning that the subtree rooted at this part is to + be skipped) when it can be proved that this part can only have + child partitions which are larger than allowed by ``ub``. If a + decision is made to fail, it must be accurate, otherwise the + enumeration will miss some partitions. But, it is OK not to + capture all the possible failures -- if a part is passed that + should not be, the resulting too-large partitions are filtered + by the enumeration one level up. However, as is usual in + constrained enumerations, failing early is advantageous. + + The tests used by this method catch the most common cases, + although this implementation is by no means the last word on + this problem. The tests include: + + 1) ``lpart`` must be less than ``ub`` by at least 2. This is because + once a part has been decremented, the partition + will gain at least one child in the spread step. + + 2) If the leading component of the part is about to be + decremented, check for how many parts will be added in + order to use up the unallocated multiplicity in that + leading component, and fail if this number is greater than + allowed by ``ub``. (See code for the exact expression.) This + test is given in the answer to Knuth's problem 7.2.1.5.69. + + 3) If there is *exactly* enough room to expand the leading + component by the above test, check the next component (if + it exists) once decrementing has finished. If this has + ``v == 0``, this next component will push the expansion over the + limit by 1, so fail. + """ + if self.lpart >= ub - 1: + self.p1 += 1 # increment to keep track of usefulness of tests + return False + plen = len(part) + for j in range(plen - 1, -1, -1): + # Knuth's mod, (answer to problem 7.2.1.5.69) + if j == 0 and (part[0].v - 1)*(ub - self.lpart) < part[0].u: + self.k1 += 1 + return False + + if j == 0 and part[j].v > 1 or j > 0 and part[j].v > 0: + # found val to decrement + part[j].v -= 1 + # Reset trailing parts back to maximum + for k in range(j + 1, plen): + part[k].v = part[k].u + + # Have now decremented part, but are we doomed to + # failure when it is expanded? Check one oddball case + # that turns out to be surprisingly common - exactly + # enough room to expand the leading component, but no + # room for the second component, which has v=0. + if (plen > 1 and part[1].v == 0 and + (part[0].u - part[0].v) == + ((ub - self.lpart - 1) * part[0].v)): + self.k2 += 1 + self.db_trace("Decrement fails test 3") + return False + return True + return False + + def decrement_part_large(self, part, amt, lb): + """Decrements part, while respecting size constraint. + + A part can have no children which are of sufficient size (as + indicated by ``lb``) unless that part has sufficient + unallocated multiplicity. When enforcing the size constraint, + this method will decrement the part (if necessary) by an + amount needed to ensure sufficient unallocated multiplicity. + + Returns True iff the part was successfully decremented. + + Parameters + ========== + + part + part to be decremented (topmost part on the stack) + + amt + Can only take values 0 or 1. A value of 1 means that the + part must be decremented, and then the size constraint is + enforced. A value of 0 means just to enforce the ``lb`` + size constraint. + + lb + The partitions produced by the calling enumeration must + have more parts than this value. + + """ + + if amt == 1: + # In this case we always need to decrement, *before* + # enforcing the "sufficient unallocated multiplicity" + # constraint. Easiest for this is just to call the + # regular decrement method. + if not self.decrement_part(part): + return False + + # Next, perform any needed additional decrementing to respect + # "sufficient unallocated multiplicity" (or fail if this is + # not possible). + min_unalloc = lb - self.lpart + if min_unalloc <= 0: + return True + total_mult = sum(pc.u for pc in part) + total_alloc = sum(pc.v for pc in part) + if total_mult <= min_unalloc: + return False + + deficit = min_unalloc - (total_mult - total_alloc) + if deficit <= 0: + return True + + for i in range(len(part) - 1, -1, -1): + if i == 0: + if part[0].v > deficit: + part[0].v -= deficit + return True + else: + return False # This shouldn't happen, due to above check + else: + if part[i].v >= deficit: + part[i].v -= deficit + return True + else: + deficit -= part[i].v + part[i].v = 0 + + def decrement_part_range(self, part, lb, ub): + """Decrements part (a subrange of pstack), if possible, returning + True iff the part was successfully decremented. + + Parameters + ========== + + part + part to be decremented (topmost part on the stack) + + ub + the maximum number of parts allowed in a partition + returned by the calling traversal. + + lb + The partitions produced by the calling enumeration must + have more parts than this value. + + Notes + ===== + + Combines the constraints of _small and _large decrement + methods. If returns success, part has been decremented at + least once, but perhaps by quite a bit more if needed to meet + the lb constraint. + """ + + # Constraint in the range case is just enforcing both the + # constraints from _small and _large cases. Note the 0 as the + # second argument to the _large call -- this is the signal to + # decrement only as needed to for constraint enforcement. The + # short circuiting and left-to-right order of the 'and' + # operator is important for this to work correctly. + return self.decrement_part_small(part, ub) and \ + self.decrement_part_large(part, 0, lb) + + def spread_part_multiplicity(self): + """Returns True if a new part has been created, and + adjusts pstack, f and lpart as needed. + + Notes + ===== + + Spreads unallocated multiplicity from the current top part + into a new part created above the current on the stack. This + new part is constrained to be less than or equal to the old in + terms of the part ordering. + + This call does nothing (and returns False) if the current top + part has no unallocated multiplicity. + + """ + j = self.f[self.lpart] # base of current top part + k = self.f[self.lpart + 1] # ub of current; potential base of next + base = k # save for later comparison + + changed = False # Set to true when the new part (so far) is + # strictly less than (as opposed to less than + # or equal) to the old. + for j in range(self.f[self.lpart], self.f[self.lpart + 1]): + self.pstack[k].u = self.pstack[j].u - self.pstack[j].v + if self.pstack[k].u == 0: + changed = True + else: + self.pstack[k].c = self.pstack[j].c + if changed: # Put all available multiplicity in this part + self.pstack[k].v = self.pstack[k].u + else: # Still maintaining ordering constraint + if self.pstack[k].u < self.pstack[j].v: + self.pstack[k].v = self.pstack[k].u + changed = True + else: + self.pstack[k].v = self.pstack[j].v + k = k + 1 + if k > base: + # Adjust for the new part on stack + self.lpart = self.lpart + 1 + self.f[self.lpart + 1] = k + return True + return False + + def top_part(self): + """Return current top part on the stack, as a slice of pstack. + + """ + return self.pstack[self.f[self.lpart]:self.f[self.lpart + 1]] + + # Same interface and functionality as multiset_partitions_taocp(), + # but some might find this refactored version easier to follow. + def enum_all(self, multiplicities): + """Enumerate the partitions of a multiset. + + Examples + ======== + + >>> from sympy.utilities.enumerative import list_visitor + >>> from sympy.utilities.enumerative import MultisetPartitionTraverser + >>> m = MultisetPartitionTraverser() + >>> states = m.enum_all([2,2]) + >>> list(list_visitor(state, 'ab') for state in states) + [[['a', 'a', 'b', 'b']], + [['a', 'a', 'b'], ['b']], + [['a', 'a'], ['b', 'b']], + [['a', 'a'], ['b'], ['b']], + [['a', 'b', 'b'], ['a']], + [['a', 'b'], ['a', 'b']], + [['a', 'b'], ['a'], ['b']], + [['a'], ['a'], ['b', 'b']], + [['a'], ['a'], ['b'], ['b']]] + + See Also + ======== + + multiset_partitions_taocp: + which provides the same result as this method, but is + about twice as fast. Hence, enum_all is primarily useful + for testing. Also see the function for a discussion of + states and visitors. + + """ + self._initialize_enumeration(multiplicities) + while True: + while self.spread_part_multiplicity(): + pass + + # M4 Visit a partition + state = [self.f, self.lpart, self.pstack] + yield state + + # M5 (Decrease v) + while not self.decrement_part(self.top_part()): + # M6 (Backtrack) + if self.lpart == 0: + return + self.lpart -= 1 + + def enum_small(self, multiplicities, ub): + """Enumerate multiset partitions with no more than ``ub`` parts. + + Equivalent to enum_range(multiplicities, 0, ub) + + Parameters + ========== + + multiplicities + list of multiplicities of the components of the multiset. + + ub + Maximum number of parts + + Examples + ======== + + >>> from sympy.utilities.enumerative import list_visitor + >>> from sympy.utilities.enumerative import MultisetPartitionTraverser + >>> m = MultisetPartitionTraverser() + >>> states = m.enum_small([2,2], 2) + >>> list(list_visitor(state, 'ab') for state in states) + [[['a', 'a', 'b', 'b']], + [['a', 'a', 'b'], ['b']], + [['a', 'a'], ['b', 'b']], + [['a', 'b', 'b'], ['a']], + [['a', 'b'], ['a', 'b']]] + + The implementation is based, in part, on the answer given to + exercise 69, in Knuth [AOCP]_. + + See Also + ======== + + enum_all, enum_large, enum_range + + """ + + # Keep track of iterations which do not yield a partition. + # Clearly, we would like to keep this number small. + self.discarded = 0 + if ub <= 0: + return + self._initialize_enumeration(multiplicities) + while True: + while self.spread_part_multiplicity(): + self.db_trace('spread 1') + if self.lpart >= ub: + self.discarded += 1 + self.db_trace(' Discarding') + self.lpart = ub - 2 + break + else: + # M4 Visit a partition + state = [self.f, self.lpart, self.pstack] + yield state + + # M5 (Decrease v) + while not self.decrement_part_small(self.top_part(), ub): + self.db_trace("Failed decrement, going to backtrack") + # M6 (Backtrack) + if self.lpart == 0: + return + self.lpart -= 1 + self.db_trace("Backtracked to") + self.db_trace("decrement ok, about to expand") + + def enum_large(self, multiplicities, lb): + """Enumerate the partitions of a multiset with lb < num(parts) + + Equivalent to enum_range(multiplicities, lb, sum(multiplicities)) + + Parameters + ========== + + multiplicities + list of multiplicities of the components of the multiset. + + lb + Number of parts in the partition must be greater than + this lower bound. + + + Examples + ======== + + >>> from sympy.utilities.enumerative import list_visitor + >>> from sympy.utilities.enumerative import MultisetPartitionTraverser + >>> m = MultisetPartitionTraverser() + >>> states = m.enum_large([2,2], 2) + >>> list(list_visitor(state, 'ab') for state in states) + [[['a', 'a'], ['b'], ['b']], + [['a', 'b'], ['a'], ['b']], + [['a'], ['a'], ['b', 'b']], + [['a'], ['a'], ['b'], ['b']]] + + See Also + ======== + + enum_all, enum_small, enum_range + + """ + self.discarded = 0 + if lb >= sum(multiplicities): + return + self._initialize_enumeration(multiplicities) + self.decrement_part_large(self.top_part(), 0, lb) + while True: + good_partition = True + while self.spread_part_multiplicity(): + if not self.decrement_part_large(self.top_part(), 0, lb): + # Failure here should be rare/impossible + self.discarded += 1 + good_partition = False + break + + # M4 Visit a partition + if good_partition: + state = [self.f, self.lpart, self.pstack] + yield state + + # M5 (Decrease v) + while not self.decrement_part_large(self.top_part(), 1, lb): + # M6 (Backtrack) + if self.lpart == 0: + return + self.lpart -= 1 + + def enum_range(self, multiplicities, lb, ub): + + """Enumerate the partitions of a multiset with + ``lb < num(parts) <= ub``. + + In particular, if partitions with exactly ``k`` parts are + desired, call with ``(multiplicities, k - 1, k)``. This + method generalizes enum_all, enum_small, and enum_large. + + Examples + ======== + + >>> from sympy.utilities.enumerative import list_visitor + >>> from sympy.utilities.enumerative import MultisetPartitionTraverser + >>> m = MultisetPartitionTraverser() + >>> states = m.enum_range([2,2], 1, 2) + >>> list(list_visitor(state, 'ab') for state in states) + [[['a', 'a', 'b'], ['b']], + [['a', 'a'], ['b', 'b']], + [['a', 'b', 'b'], ['a']], + [['a', 'b'], ['a', 'b']]] + + """ + # combine the constraints of the _large and _small + # enumerations. + self.discarded = 0 + if ub <= 0 or lb >= sum(multiplicities): + return + self._initialize_enumeration(multiplicities) + self.decrement_part_large(self.top_part(), 0, lb) + while True: + good_partition = True + while self.spread_part_multiplicity(): + self.db_trace("spread 1") + if not self.decrement_part_large(self.top_part(), 0, lb): + # Failure here - possible in range case? + self.db_trace(" Discarding (large cons)") + self.discarded += 1 + good_partition = False + break + elif self.lpart >= ub: + self.discarded += 1 + good_partition = False + self.db_trace(" Discarding small cons") + self.lpart = ub - 2 + break + + # M4 Visit a partition + if good_partition: + state = [self.f, self.lpart, self.pstack] + yield state + + # M5 (Decrease v) + while not self.decrement_part_range(self.top_part(), lb, ub): + self.db_trace("Failed decrement, going to backtrack") + # M6 (Backtrack) + if self.lpart == 0: + return + self.lpart -= 1 + self.db_trace("Backtracked to") + self.db_trace("decrement ok, about to expand") + + def count_partitions_slow(self, multiplicities): + """Returns the number of partitions of a multiset whose elements + have the multiplicities given in ``multiplicities``. + + Primarily for comparison purposes. It follows the same path as + enumerate, and counts, rather than generates, the partitions. + + See Also + ======== + + count_partitions + Has the same calling interface, but is much faster. + + """ + # number of partitions so far in the enumeration + self.pcount = 0 + self._initialize_enumeration(multiplicities) + while True: + while self.spread_part_multiplicity(): + pass + + # M4 Visit (count) a partition + self.pcount += 1 + + # M5 (Decrease v) + while not self.decrement_part(self.top_part()): + # M6 (Backtrack) + if self.lpart == 0: + return self.pcount + self.lpart -= 1 + + def count_partitions(self, multiplicities): + """Returns the number of partitions of a multiset whose components + have the multiplicities given in ``multiplicities``. + + For larger counts, this method is much faster than calling one + of the enumerators and counting the result. Uses dynamic + programming to cut down on the number of nodes actually + explored. The dictionary used in order to accelerate the + counting process is stored in the ``MultisetPartitionTraverser`` + object and persists across calls. If the user does not + expect to call ``count_partitions`` for any additional + multisets, the object should be cleared to save memory. On + the other hand, the cache built up from one count run can + significantly speed up subsequent calls to ``count_partitions``, + so it may be advantageous not to clear the object. + + Examples + ======== + + >>> from sympy.utilities.enumerative import MultisetPartitionTraverser + >>> m = MultisetPartitionTraverser() + >>> m.count_partitions([9,8,2]) + 288716 + >>> m.count_partitions([2,2]) + 9 + >>> del m + + Notes + ===== + + If one looks at the workings of Knuth's algorithm M [AOCP]_, it + can be viewed as a traversal of a binary tree of parts. A + part has (up to) two children, the left child resulting from + the spread operation, and the right child from the decrement + operation. The ordinary enumeration of multiset partitions is + an in-order traversal of this tree, and with the partitions + corresponding to paths from the root to the leaves. The + mapping from paths to partitions is a little complicated, + since the partition would contain only those parts which are + leaves or the parents of a spread link, not those which are + parents of a decrement link. + + For counting purposes, it is sufficient to count leaves, and + this can be done with a recursive in-order traversal. The + number of leaves of a subtree rooted at a particular part is a + function only of that part itself, so memoizing has the + potential to speed up the counting dramatically. + + This method follows a computational approach which is similar + to the hypothetical memoized recursive function, but with two + differences: + + 1) This method is iterative, borrowing its structure from the + other enumerations and maintaining an explicit stack of + parts which are in the process of being counted. (There + may be multisets which can be counted reasonably quickly by + this implementation, but which would overflow the default + Python recursion limit with a recursive implementation.) + + 2) Instead of using the part data structure directly, a more + compact key is constructed. This saves space, but more + importantly coalesces some parts which would remain + separate with physical keys. + + Unlike the enumeration functions, there is currently no _range + version of count_partitions. If someone wants to stretch + their brain, it should be possible to construct one by + memoizing with a histogram of counts rather than a single + count, and combining the histograms. + """ + # number of partitions so far in the enumeration + self.pcount = 0 + + # dp_stack is list of lists of (part_key, start_count) pairs + self.dp_stack = [] + + self._initialize_enumeration(multiplicities) + pkey = part_key(self.top_part()) + self.dp_stack.append([(pkey, 0), ]) + while True: + while self.spread_part_multiplicity(): + pkey = part_key(self.top_part()) + if pkey in self.dp_map: + # Already have a cached value for the count of the + # subtree rooted at this part. Add it to the + # running counter, and break out of the spread + # loop. The -1 below is to compensate for the + # leaf that this code path would otherwise find, + # and which gets incremented for below. + + self.pcount += (self.dp_map[pkey] - 1) + self.lpart -= 1 + break + else: + self.dp_stack.append([(pkey, self.pcount), ]) + + # M4 count a leaf partition + self.pcount += 1 + + # M5 (Decrease v) + while not self.decrement_part(self.top_part()): + # M6 (Backtrack) + for key, oldcount in self.dp_stack.pop(): + self.dp_map[key] = self.pcount - oldcount + if self.lpart == 0: + return self.pcount + self.lpart -= 1 + + # At this point have successfully decremented the part on + # the stack and it does not appear in the cache. It needs + # to be added to the list at the top of dp_stack + pkey = part_key(self.top_part()) + self.dp_stack[-1].append((pkey, self.pcount),) + + +def part_key(part): + """Helper for MultisetPartitionTraverser.count_partitions that + creates a key for ``part``, that only includes information which can + affect the count for that part. (Any irrelevant information just + reduces the effectiveness of dynamic programming.) + + Notes + ===== + + This member function is a candidate for future exploration. There + are likely symmetries that can be exploited to coalesce some + ``part_key`` values, and thereby save space and improve + performance. + + """ + # The component number is irrelevant for counting partitions, so + # leave it out of the memo key. + rval = [] + for ps in part: + rval.append(ps.u) + rval.append(ps.v) + return tuple(rval) diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/exceptions.py b/.venv/lib/python3.13/site-packages/sympy/utilities/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..f764a31371ed109d69102777bd9ec0dfb3d75380 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/exceptions.py @@ -0,0 +1,271 @@ +""" +General SymPy exceptions and warnings. +""" + +import warnings +import contextlib + +from textwrap import dedent + + +class SymPyDeprecationWarning(DeprecationWarning): + r""" + A warning for deprecated features of SymPy. + + See the :ref:`deprecation-policy` document for details on when and how + things should be deprecated in SymPy. + + Note that simply constructing this class will not cause a warning to be + issued. To do that, you must call the :func`sympy_deprecation_warning` + function. For this reason, it is not recommended to ever construct this + class directly. + + Explanation + =========== + + The ``SymPyDeprecationWarning`` class is a subclass of + ``DeprecationWarning`` that is used for all deprecations in SymPy. A + special subclass is used so that we can automatically augment the warning + message with additional metadata about the version the deprecation was + introduced in and a link to the documentation. This also allows users to + explicitly filter deprecation warnings from SymPy using ``warnings`` + filters (see :ref:`silencing-sympy-deprecation-warnings`). + + Additionally, ``SymPyDeprecationWarning`` is enabled to be shown by + default, unlike normal ``DeprecationWarning``\s, which are only shown by + default in interactive sessions. This ensures that deprecation warnings in + SymPy will actually be seen by users. + + See the documentation of :func:`sympy_deprecation_warning` for a + description of the parameters to this function. + + To mark a function as deprecated, you can use the :func:`@deprecated + ` decorator. + + See Also + ======== + sympy.utilities.exceptions.sympy_deprecation_warning + sympy.utilities.exceptions.ignore_warnings + sympy.utilities.decorator.deprecated + sympy.testing.pytest.warns_deprecated_sympy + + """ + def __init__(self, message, *, deprecated_since_version, active_deprecations_target): + + super().__init__(message, deprecated_since_version, + active_deprecations_target) + self.message = message + if not isinstance(deprecated_since_version, str): + raise TypeError(f"'deprecated_since_version' should be a string, got {deprecated_since_version!r}") + self.deprecated_since_version = deprecated_since_version + self.active_deprecations_target = active_deprecations_target + if any(i in active_deprecations_target for i in '()='): + raise ValueError("active_deprecations_target be the part inside of the '(...)='") + + self.full_message = f""" + +{dedent(message).strip()} + +See https://docs.sympy.org/latest/explanation/active-deprecations.html#{active_deprecations_target} +for details. + +This has been deprecated since SymPy version {deprecated_since_version}. It +will be removed in a future version of SymPy. +""" + + def __str__(self): + return self.full_message + + def __repr__(self): + return f"{self.__class__.__name__}({self.message!r}, deprecated_since_version={self.deprecated_since_version!r}, active_deprecations_target={self.active_deprecations_target!r})" + + def __eq__(self, other): + return isinstance(other, SymPyDeprecationWarning) and self.args == other.args + + # Make pickling work. The by default, it tries to recreate the expression + # from its args, but this doesn't work because of our keyword-only + # arguments. + @classmethod + def _new(cls, message, deprecated_since_version, + active_deprecations_target): + return cls(message, deprecated_since_version=deprecated_since_version, active_deprecations_target=active_deprecations_target) + + def __reduce__(self): + return (self._new, (self.message, self.deprecated_since_version, self.active_deprecations_target)) + +# Python by default hides DeprecationWarnings, which we do not want. +warnings.simplefilter("once", SymPyDeprecationWarning) + +def sympy_deprecation_warning(message, *, deprecated_since_version, + active_deprecations_target, stacklevel=3): + r''' + Warn that a feature is deprecated in SymPy. + + See the :ref:`deprecation-policy` document for details on when and how + things should be deprecated in SymPy. + + To mark an entire function or class as deprecated, you can use the + :func:`@deprecated ` decorator. + + Parameters + ========== + + message : str + The deprecation message. This may span multiple lines and contain + code examples. Messages should be wrapped to 80 characters. The + message is automatically dedented and leading and trailing whitespace + stripped. Messages may include dynamic content based on the user + input, but avoid using ``str(expression)`` if an expression can be + arbitrary, as it might be huge and make the warning message + unreadable. + + deprecated_since_version : str + The version of SymPy the feature has been deprecated since. For new + deprecations, this should be the version in `sympy/release.py + `_ + without the ``.dev``. If the next SymPy version ends up being + different from this, the release manager will need to update any + ``SymPyDeprecationWarning``\s using the incorrect version. This + argument is required and must be passed as a keyword argument. + (example: ``deprecated_since_version="1.10"``). + + active_deprecations_target : str + The Sphinx target corresponding to the section for the deprecation in + the :ref:`active-deprecations` document (see + ``doc/src/explanation/active-deprecations.md``). This is used to + automatically generate a URL to the page in the warning message. This + argument is required and must be passed as a keyword argument. + (example: ``active_deprecations_target="deprecated-feature-abc"``) + + stacklevel : int, default: 3 + The ``stacklevel`` parameter that is passed to ``warnings.warn``. If + you create a wrapper that calls this function, this should be + increased so that the warning message shows the user line of code that + produced the warning. Note that in some cases there will be multiple + possible different user code paths that could result in the warning. + In that case, just choose the smallest common stacklevel. + + Examples + ======== + + >>> from sympy.utilities.exceptions import sympy_deprecation_warning + >>> def is_this_zero(x, y=0): + ... """ + ... Determine if x = 0. + ... + ... Parameters + ... ========== + ... + ... x : Expr + ... The expression to check. + ... + ... y : Expr, optional + ... If provided, check if x = y. + ... + ... .. deprecated:: 1.1 + ... + ... The ``y`` argument to ``is_this_zero`` is deprecated. Use + ... ``is_this_zero(x - y)`` instead. + ... + ... """ + ... from sympy import simplify + ... + ... if y != 0: + ... sympy_deprecation_warning(""" + ... The y argument to is_zero() is deprecated. Use is_zero(x - y) instead.""", + ... deprecated_since_version="1.1", + ... active_deprecations_target='is-this-zero-y-deprecation') + ... return simplify(x - y) == 0 + >>> is_this_zero(0) + True + >>> is_this_zero(1, 1) # doctest: +SKIP + :1: SymPyDeprecationWarning: + + The y argument to is_zero() is deprecated. Use is_zero(x - y) instead. + + See https://docs.sympy.org/latest/explanation/active-deprecations.html#is-this-zero-y-deprecation + for details. + + This has been deprecated since SymPy version 1.1. It + will be removed in a future version of SymPy. + + is_this_zero(1, 1) + True + + See Also + ======== + + sympy.utilities.exceptions.SymPyDeprecationWarning + sympy.utilities.exceptions.ignore_warnings + sympy.utilities.decorator.deprecated + sympy.testing.pytest.warns_deprecated_sympy + + ''' + w = SymPyDeprecationWarning(message, + deprecated_since_version=deprecated_since_version, + active_deprecations_target=active_deprecations_target) + warnings.warn(w, stacklevel=stacklevel) + + +@contextlib.contextmanager +def ignore_warnings(warningcls): + ''' + Context manager to suppress warnings during tests. + + .. note:: + + Do not use this with SymPyDeprecationWarning in the tests. + warns_deprecated_sympy() should be used instead. + + This function is useful for suppressing warnings during tests. The warns + function should be used to assert that a warning is raised. The + ignore_warnings function is useful in situation when the warning is not + guaranteed to be raised (e.g. on importing a module) or if the warning + comes from third-party code. + + This function is also useful to prevent the same or similar warnings from + being issue twice due to recursive calls. + + When the warning is coming (reliably) from SymPy the warns function should + be preferred to ignore_warnings. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> import warnings + + Here's a warning: + + >>> with warnings.catch_warnings(): # reset warnings in doctest + ... warnings.simplefilter('error') + ... warnings.warn('deprecated', UserWarning) + Traceback (most recent call last): + ... + UserWarning: deprecated + + Let's suppress it with ignore_warnings: + + >>> with warnings.catch_warnings(): # reset warnings in doctest + ... warnings.simplefilter('error') + ... with ignore_warnings(UserWarning): + ... warnings.warn('deprecated', UserWarning) + + (No warning emitted) + + See Also + ======== + sympy.utilities.exceptions.SymPyDeprecationWarning + sympy.utilities.exceptions.sympy_deprecation_warning + sympy.utilities.decorator.deprecated + sympy.testing.pytest.warns_deprecated_sympy + + ''' + # Absorbs all warnings in warnrec + with warnings.catch_warnings(record=True) as warnrec: + # Make sure our warning doesn't get filtered + warnings.simplefilter("always", warningcls) + # Now run the test + yield + + # Reissue any warnings that we aren't testing for + for w in warnrec: + if not issubclass(w.category, warningcls): + warnings.warn_explicit(w.message, w.category, w.filename, w.lineno) diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/iterables.py b/.venv/lib/python3.13/site-packages/sympy/utilities/iterables.py new file mode 100644 index 0000000000000000000000000000000000000000..cc5c1152f03b10f62f9dd04c842b6fc74b0b9242 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/iterables.py @@ -0,0 +1,3179 @@ +from collections import Counter, defaultdict, OrderedDict +from itertools import ( + chain, combinations, combinations_with_replacement, cycle, islice, + permutations, product, groupby +) +# For backwards compatibility +from itertools import product as cartes # noqa: F401 +from operator import gt + + + +# this is the logical location of these functions +from sympy.utilities.enumerative import ( + multiset_partitions_taocp, list_visitor, MultisetPartitionTraverser) + +from sympy.utilities.misc import as_int +from sympy.utilities.decorator import deprecated + + +def is_palindromic(s, i=0, j=None): + """ + Return True if the sequence is the same from left to right as it + is from right to left in the whole sequence (default) or in the + Python slice ``s[i: j]``; else False. + + Examples + ======== + + >>> from sympy.utilities.iterables import is_palindromic + >>> is_palindromic([1, 0, 1]) + True + >>> is_palindromic('abcbb') + False + >>> is_palindromic('abcbb', 1) + False + + Normal Python slicing is performed in place so there is no need to + create a slice of the sequence for testing: + + >>> is_palindromic('abcbb', 1, -1) + True + >>> is_palindromic('abcbb', -4, -1) + True + + See Also + ======== + + sympy.ntheory.digits.is_palindromic: tests integers + + """ + i, j, _ = slice(i, j).indices(len(s)) + m = (j - i)//2 + # if length is odd, middle element will be ignored + return all(s[i + k] == s[j - 1 - k] for k in range(m)) + + +def flatten(iterable, levels=None, cls=None): # noqa: F811 + """ + Recursively denest iterable containers. + + >>> from sympy import flatten + + >>> flatten([1, 2, 3]) + [1, 2, 3] + >>> flatten([1, 2, [3]]) + [1, 2, 3] + >>> flatten([1, [2, 3], [4, 5]]) + [1, 2, 3, 4, 5] + >>> flatten([1.0, 2, (1, None)]) + [1.0, 2, 1, None] + + If you want to denest only a specified number of levels of + nested containers, then set ``levels`` flag to the desired + number of levels:: + + >>> ls = [[(-2, -1), (1, 2)], [(0, 0)]] + + >>> flatten(ls, levels=1) + [(-2, -1), (1, 2), (0, 0)] + + If cls argument is specified, it will only flatten instances of that + class, for example: + + >>> from sympy import Basic, S + >>> class MyOp(Basic): + ... pass + ... + >>> flatten([MyOp(S(1), MyOp(S(2), S(3)))], cls=MyOp) + [1, 2, 3] + + adapted from https://kogs-www.informatik.uni-hamburg.de/~meine/python_tricks + """ + from sympy.tensor.array import NDimArray + if levels is not None: + if not levels: + return iterable + elif levels > 0: + levels -= 1 + else: + raise ValueError( + "expected non-negative number of levels, got %s" % levels) + + if cls is None: + def reducible(x): + return is_sequence(x, set) + else: + def reducible(x): + return isinstance(x, cls) + + result = [] + + for el in iterable: + if reducible(el): + if hasattr(el, 'args') and not isinstance(el, NDimArray): + el = el.args + result.extend(flatten(el, levels=levels, cls=cls)) + else: + result.append(el) + + return result + + +def unflatten(iter, n=2): + """Group ``iter`` into tuples of length ``n``. Raise an error if + the length of ``iter`` is not a multiple of ``n``. + """ + if n < 1 or len(iter) % n: + raise ValueError('iter length is not a multiple of %i' % n) + return list(zip(*(iter[i::n] for i in range(n)))) + + +def reshape(seq, how): + """Reshape the sequence according to the template in ``how``. + + Examples + ======== + + >>> from sympy.utilities import reshape + >>> seq = list(range(1, 9)) + + >>> reshape(seq, [4]) # lists of 4 + [[1, 2, 3, 4], [5, 6, 7, 8]] + + >>> reshape(seq, (4,)) # tuples of 4 + [(1, 2, 3, 4), (5, 6, 7, 8)] + + >>> reshape(seq, (2, 2)) # tuples of 4 + [(1, 2, 3, 4), (5, 6, 7, 8)] + + >>> reshape(seq, (2, [2])) # (i, i, [i, i]) + [(1, 2, [3, 4]), (5, 6, [7, 8])] + + >>> reshape(seq, ((2,), [2])) # etc.... + [((1, 2), [3, 4]), ((5, 6), [7, 8])] + + >>> reshape(seq, (1, [2], 1)) + [(1, [2, 3], 4), (5, [6, 7], 8)] + + >>> reshape(tuple(seq), ([[1], 1, (2,)],)) + (([[1], 2, (3, 4)],), ([[5], 6, (7, 8)],)) + + >>> reshape(tuple(seq), ([1], 1, (2,))) + (([1], 2, (3, 4)), ([5], 6, (7, 8))) + + >>> reshape(list(range(12)), [2, [3], {2}, (1, (3,), 1)]) + [[0, 1, [2, 3, 4], {5, 6}, (7, (8, 9, 10), 11)]] + + """ + m = sum(flatten(how)) + n, rem = divmod(len(seq), m) + if m < 0 or rem: + raise ValueError('template must sum to positive number ' + 'that divides the length of the sequence') + i = 0 + container = type(how) + rv = [None]*n + for k in range(len(rv)): + _rv = [] + for hi in how: + if isinstance(hi, int): + _rv.extend(seq[i: i + hi]) + i += hi + else: + n = sum(flatten(hi)) + hi_type = type(hi) + _rv.append(hi_type(reshape(seq[i: i + n], hi)[0])) + i += n + rv[k] = container(_rv) + return type(seq)(rv) + + +def group(seq, multiple=True): + """ + Splits a sequence into a list of lists of equal, adjacent elements. + + Examples + ======== + + >>> from sympy import group + + >>> group([1, 1, 1, 2, 2, 3]) + [[1, 1, 1], [2, 2], [3]] + >>> group([1, 1, 1, 2, 2, 3], multiple=False) + [(1, 3), (2, 2), (3, 1)] + >>> group([1, 1, 3, 2, 2, 1], multiple=False) + [(1, 2), (3, 1), (2, 2), (1, 1)] + + See Also + ======== + + multiset + + """ + if multiple: + return [(list(g)) for _, g in groupby(seq)] + return [(k, len(list(g))) for k, g in groupby(seq)] + + +def _iproduct2(iterable1, iterable2): + '''Cartesian product of two possibly infinite iterables''' + + it1 = iter(iterable1) + it2 = iter(iterable2) + + elems1 = [] + elems2 = [] + + sentinel = object() + def append(it, elems): + e = next(it, sentinel) + if e is not sentinel: + elems.append(e) + + n = 0 + append(it1, elems1) + append(it2, elems2) + + while n <= len(elems1) + len(elems2): + for m in range(n-len(elems1)+1, len(elems2)): + yield (elems1[n-m], elems2[m]) + n += 1 + append(it1, elems1) + append(it2, elems2) + + +def iproduct(*iterables): + ''' + Cartesian product of iterables. + + Generator of the Cartesian product of iterables. This is analogous to + itertools.product except that it works with infinite iterables and will + yield any item from the infinite product eventually. + + Examples + ======== + + >>> from sympy.utilities.iterables import iproduct + >>> sorted(iproduct([1,2], [3,4])) + [(1, 3), (1, 4), (2, 3), (2, 4)] + + With an infinite iterator: + + >>> from sympy import S + >>> (3,) in iproduct(S.Integers) + True + >>> (3, 4) in iproduct(S.Integers, S.Integers) + True + + .. seealso:: + + `itertools.product + `_ + ''' + if len(iterables) == 0: + yield () + return + elif len(iterables) == 1: + for e in iterables[0]: + yield (e,) + elif len(iterables) == 2: + yield from _iproduct2(*iterables) + else: + first, others = iterables[0], iterables[1:] + for ef, eo in _iproduct2(first, iproduct(*others)): + yield (ef,) + eo + + +def multiset(seq): + """Return the hashable sequence in multiset form with values being the + multiplicity of the item in the sequence. + + Examples + ======== + + >>> from sympy.utilities.iterables import multiset + >>> multiset('mississippi') + {'i': 4, 'm': 1, 'p': 2, 's': 4} + + See Also + ======== + + group + + """ + return dict(Counter(seq).items()) + + + + +def ibin(n, bits=None, str=False): + """Return a list of length ``bits`` corresponding to the binary value + of ``n`` with small bits to the right (last). If bits is omitted, the + length will be the number required to represent ``n``. If the bits are + desired in reversed order, use the ``[::-1]`` slice of the returned list. + + If a sequence of all bits-length lists starting from ``[0, 0,..., 0]`` + through ``[1, 1, ..., 1]`` are desired, pass a non-integer for bits, e.g. + ``'all'``. + + If the bit *string* is desired pass ``str=True``. + + Examples + ======== + + >>> from sympy.utilities.iterables import ibin + >>> ibin(2) + [1, 0] + >>> ibin(2, 4) + [0, 0, 1, 0] + + If all lists corresponding to 0 to 2**n - 1, pass a non-integer + for bits: + + >>> bits = 2 + >>> for i in ibin(2, 'all'): + ... print(i) + (0, 0) + (0, 1) + (1, 0) + (1, 1) + + If a bit string is desired of a given length, use str=True: + + >>> n = 123 + >>> bits = 10 + >>> ibin(n, bits, str=True) + '0001111011' + >>> ibin(n, bits, str=True)[::-1] # small bits left + '1101111000' + >>> list(ibin(3, 'all', str=True)) + ['000', '001', '010', '011', '100', '101', '110', '111'] + + """ + if n < 0: + raise ValueError("negative numbers are not allowed") + n = as_int(n) + + if bits is None: + bits = 0 + else: + try: + bits = as_int(bits) + except ValueError: + bits = -1 + else: + if n.bit_length() > bits: + raise ValueError( + "`bits` must be >= {}".format(n.bit_length())) + + if not str: + if bits >= 0: + return [1 if i == "1" else 0 for i in bin(n)[2:].rjust(bits, "0")] + else: + return variations(range(2), n, repetition=True) + else: + if bits >= 0: + return bin(n)[2:].rjust(bits, "0") + else: + return (bin(i)[2:].rjust(n, "0") for i in range(2**n)) + + +def variations(seq, n, repetition=False): + r"""Returns an iterator over the n-sized variations of ``seq`` (size N). + ``repetition`` controls whether items in ``seq`` can appear more than once; + + Examples + ======== + + ``variations(seq, n)`` will return `\frac{N!}{(N - n)!}` permutations without + repetition of ``seq``'s elements: + + >>> from sympy import variations + >>> list(variations([1, 2], 2)) + [(1, 2), (2, 1)] + + ``variations(seq, n, True)`` will return the `N^n` permutations obtained + by allowing repetition of elements: + + >>> list(variations([1, 2], 2, repetition=True)) + [(1, 1), (1, 2), (2, 1), (2, 2)] + + If you ask for more items than are in the set you get the empty set unless + you allow repetitions: + + >>> list(variations([0, 1], 3, repetition=False)) + [] + >>> list(variations([0, 1], 3, repetition=True))[:4] + [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1)] + + .. seealso:: + + `itertools.permutations + `_, + `itertools.product + `_ + """ + if not repetition: + seq = tuple(seq) + if len(seq) < n: + return iter(()) # 0 length iterator + return permutations(seq, n) + else: + if n == 0: + return iter(((),)) # yields 1 empty tuple + else: + return product(seq, repeat=n) + + +def subsets(seq, k=None, repetition=False): + r"""Generates all `k`-subsets (combinations) from an `n`-element set, ``seq``. + + A `k`-subset of an `n`-element set is any subset of length exactly `k`. The + number of `k`-subsets of an `n`-element set is given by ``binomial(n, k)``, + whereas there are `2^n` subsets all together. If `k` is ``None`` then all + `2^n` subsets will be returned from shortest to longest. + + Examples + ======== + + >>> from sympy import subsets + + ``subsets(seq, k)`` will return the + `\frac{n!}{k!(n - k)!}` `k`-subsets (combinations) + without repetition, i.e. once an item has been removed, it can no + longer be "taken": + + >>> list(subsets([1, 2], 2)) + [(1, 2)] + >>> list(subsets([1, 2])) + [(), (1,), (2,), (1, 2)] + >>> list(subsets([1, 2, 3], 2)) + [(1, 2), (1, 3), (2, 3)] + + + ``subsets(seq, k, repetition=True)`` will return the + `\frac{(n - 1 + k)!}{k!(n - 1)!}` + combinations *with* repetition: + + >>> list(subsets([1, 2], 2, repetition=True)) + [(1, 1), (1, 2), (2, 2)] + + If you ask for more items than are in the set you get the empty set unless + you allow repetitions: + + >>> list(subsets([0, 1], 3, repetition=False)) + [] + >>> list(subsets([0, 1], 3, repetition=True)) + [(0, 0, 0), (0, 0, 1), (0, 1, 1), (1, 1, 1)] + + """ + if k is None: + if not repetition: + return chain.from_iterable((combinations(seq, k) + for k in range(len(seq) + 1))) + else: + return chain.from_iterable((combinations_with_replacement(seq, k) + for k in range(len(seq) + 1))) + else: + if not repetition: + return combinations(seq, k) + else: + return combinations_with_replacement(seq, k) + + +def filter_symbols(iterator, exclude): + """ + Only yield elements from `iterator` that do not occur in `exclude`. + + Parameters + ========== + + iterator : iterable + iterator to take elements from + + exclude : iterable + elements to exclude + + Returns + ======= + + iterator : iterator + filtered iterator + """ + exclude = set(exclude) + for s in iterator: + if s not in exclude: + yield s + +def numbered_symbols(prefix='x', cls=None, start=0, exclude=(), *args, **assumptions): + """ + Generate an infinite stream of Symbols consisting of a prefix and + increasing subscripts provided that they do not occur in ``exclude``. + + Parameters + ========== + + prefix : str, optional + The prefix to use. By default, this function will generate symbols of + the form "x0", "x1", etc. + + cls : class, optional + The class to use. By default, it uses ``Symbol``, but you can also use ``Wild`` + or ``Dummy``. + + start : int, optional + The start number. By default, it is 0. + + exclude : list, tuple, set of cls, optional + Symbols to be excluded. + + *args, **kwargs + Additional positional and keyword arguments are passed to the *cls* class. + + Returns + ======= + + sym : Symbol + The subscripted symbols. + """ + exclude = set(exclude or []) + if cls is None: + # We can't just make the default cls=Symbol because it isn't + # imported yet. + from sympy.core import Symbol + cls = Symbol + + while True: + name = '%s%s' % (prefix, start) + s = cls(name, *args, **assumptions) + if s not in exclude: + yield s + start += 1 + + +def capture(func): + """Return the printed output of func(). + + ``func`` should be a function without arguments that produces output with + print statements. + + >>> from sympy.utilities.iterables import capture + >>> from sympy import pprint + >>> from sympy.abc import x + >>> def foo(): + ... print('hello world!') + ... + >>> 'hello' in capture(foo) # foo, not foo() + True + >>> capture(lambda: pprint(2/x)) + '2\\n-\\nx\\n' + + """ + from io import StringIO + import sys + + stdout = sys.stdout + sys.stdout = file = StringIO() + try: + func() + finally: + sys.stdout = stdout + return file.getvalue() + + +def sift(seq, keyfunc, binary=False): + """ + Sift the sequence, ``seq`` according to ``keyfunc``. + + Returns + ======= + + When ``binary`` is ``False`` (default), the output is a dictionary + where elements of ``seq`` are stored in a list keyed to the value + of keyfunc for that element. If ``binary`` is True then a tuple + with lists ``T`` and ``F`` are returned where ``T`` is a list + containing elements of seq for which ``keyfunc`` was ``True`` and + ``F`` containing those elements for which ``keyfunc`` was ``False``; + a ValueError is raised if the ``keyfunc`` is not binary. + + Examples + ======== + + >>> from sympy.utilities import sift + >>> from sympy.abc import x, y + >>> from sympy import sqrt, exp, pi, Tuple + + >>> sift(range(5), lambda x: x % 2) + {0: [0, 2, 4], 1: [1, 3]} + + sift() returns a defaultdict() object, so any key that has no matches will + give []. + + >>> sift([x], lambda x: x.is_commutative) + {True: [x]} + >>> _[False] + [] + + Sometimes you will not know how many keys you will get: + + >>> sift([sqrt(x), exp(x), (y**x)**2], + ... lambda x: x.as_base_exp()[0]) + {E: [exp(x)], x: [sqrt(x)], y: [y**(2*x)]} + + Sometimes you expect the results to be binary; the + results can be unpacked by setting ``binary`` to True: + + >>> sift(range(4), lambda x: x % 2, binary=True) + ([1, 3], [0, 2]) + >>> sift(Tuple(1, pi), lambda x: x.is_rational, binary=True) + ([1], [pi]) + + A ValueError is raised if the predicate was not actually binary + (which is a good test for the logic where sifting is used and + binary results were expected): + + >>> unknown = exp(1) - pi # the rationality of this is unknown + >>> args = Tuple(1, pi, unknown) + >>> sift(args, lambda x: x.is_rational, binary=True) + Traceback (most recent call last): + ... + ValueError: keyfunc gave non-binary output + + The non-binary sifting shows that there were 3 keys generated: + + >>> set(sift(args, lambda x: x.is_rational).keys()) + {None, False, True} + + If you need to sort the sifted items it might be better to use + ``ordered`` which can economically apply multiple sort keys + to a sequence while sorting. + + See Also + ======== + + ordered + + """ + if not binary: + m = defaultdict(list) + for i in seq: + m[keyfunc(i)].append(i) + return m + sift = F, T = [], [] + for i in seq: + try: + sift[keyfunc(i)].append(i) + except (IndexError, TypeError): + raise ValueError('keyfunc gave non-binary output') + return T, F + + +def take(iter, n): + """Return ``n`` items from ``iter`` iterator. """ + return [ value for _, value in zip(range(n), iter) ] + + +def dict_merge(*dicts): + """Merge dictionaries into a single dictionary. """ + merged = {} + + for dict in dicts: + merged.update(dict) + + return merged + + +def common_prefix(*seqs): + """Return the subsequence that is a common start of sequences in ``seqs``. + + >>> from sympy.utilities.iterables import common_prefix + >>> common_prefix(list(range(3))) + [0, 1, 2] + >>> common_prefix(list(range(3)), list(range(4))) + [0, 1, 2] + >>> common_prefix([1, 2, 3], [1, 2, 5]) + [1, 2] + >>> common_prefix([1, 2, 3], [1, 3, 5]) + [1] + """ + if not all(seqs): + return [] + elif len(seqs) == 1: + return seqs[0] + i = 0 + for i in range(min(len(s) for s in seqs)): + if not all(seqs[j][i] == seqs[0][i] for j in range(len(seqs))): + break + else: + i += 1 + return seqs[0][:i] + + +def common_suffix(*seqs): + """Return the subsequence that is a common ending of sequences in ``seqs``. + + >>> from sympy.utilities.iterables import common_suffix + >>> common_suffix(list(range(3))) + [0, 1, 2] + >>> common_suffix(list(range(3)), list(range(4))) + [] + >>> common_suffix([1, 2, 3], [9, 2, 3]) + [2, 3] + >>> common_suffix([1, 2, 3], [9, 7, 3]) + [3] + """ + + if not all(seqs): + return [] + elif len(seqs) == 1: + return seqs[0] + i = 0 + for i in range(-1, -min(len(s) for s in seqs) - 1, -1): + if not all(seqs[j][i] == seqs[0][i] for j in range(len(seqs))): + break + else: + i -= 1 + if i == -1: + return [] + else: + return seqs[0][i + 1:] + + +def prefixes(seq): + """ + Generate all prefixes of a sequence. + + Examples + ======== + + >>> from sympy.utilities.iterables import prefixes + + >>> list(prefixes([1,2,3,4])) + [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] + + """ + n = len(seq) + + for i in range(n): + yield seq[:i + 1] + + +def postfixes(seq): + """ + Generate all postfixes of a sequence. + + Examples + ======== + + >>> from sympy.utilities.iterables import postfixes + + >>> list(postfixes([1,2,3,4])) + [[4], [3, 4], [2, 3, 4], [1, 2, 3, 4]] + + """ + n = len(seq) + + for i in range(n): + yield seq[n - i - 1:] + + +def topological_sort(graph, key=None): + r""" + Topological sort of graph's vertices. + + Parameters + ========== + + graph : tuple[list, list[tuple[T, T]] + A tuple consisting of a list of vertices and a list of edges of + a graph to be sorted topologically. + + key : callable[T] (optional) + Ordering key for vertices on the same level. By default the natural + (e.g. lexicographic) ordering is used (in this case the base type + must implement ordering relations). + + Examples + ======== + + Consider a graph:: + + +---+ +---+ +---+ + | 7 |\ | 5 | | 3 | + +---+ \ +---+ +---+ + | _\___/ ____ _/ | + | / \___/ \ / | + V V V V | + +----+ +---+ | + | 11 | | 8 | | + +----+ +---+ | + | | \____ ___/ _ | + | \ \ / / \ | + V \ V V / V V + +---+ \ +---+ | +----+ + | 2 | | | 9 | | | 10 | + +---+ | +---+ | +----+ + \________/ + + where vertices are integers. This graph can be encoded using + elementary Python's data structures as follows:: + + >>> V = [2, 3, 5, 7, 8, 9, 10, 11] + >>> E = [(7, 11), (7, 8), (5, 11), (3, 8), (3, 10), + ... (11, 2), (11, 9), (11, 10), (8, 9)] + + To compute a topological sort for graph ``(V, E)`` issue:: + + >>> from sympy.utilities.iterables import topological_sort + + >>> topological_sort((V, E)) + [3, 5, 7, 8, 11, 2, 9, 10] + + If specific tie breaking approach is needed, use ``key`` parameter:: + + >>> topological_sort((V, E), key=lambda v: -v) + [7, 5, 11, 3, 10, 8, 9, 2] + + Only acyclic graphs can be sorted. If the input graph has a cycle, + then ``ValueError`` will be raised:: + + >>> topological_sort((V, E + [(10, 7)])) + Traceback (most recent call last): + ... + ValueError: cycle detected + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Topological_sorting + + """ + V, E = graph + + L = [] + S = set(V) + E = list(E) + + S.difference_update(u for v, u in E) + + if key is None: + def key(value): + return value + + S = sorted(S, key=key, reverse=True) + + while S: + node = S.pop() + L.append(node) + + for u, v in list(E): + if u == node: + E.remove((u, v)) + + for _u, _v in E: + if v == _v: + break + else: + kv = key(v) + + for i, s in enumerate(S): + ks = key(s) + + if kv > ks: + S.insert(i, v) + break + else: + S.append(v) + + if E: + raise ValueError("cycle detected") + else: + return L + + +def strongly_connected_components(G): + r""" + Strongly connected components of a directed graph in reverse topological + order. + + + Parameters + ========== + + G : tuple[list, list[tuple[T, T]] + A tuple consisting of a list of vertices and a list of edges of + a graph whose strongly connected components are to be found. + + + Examples + ======== + + Consider a directed graph (in dot notation):: + + digraph { + A -> B + A -> C + B -> C + C -> B + B -> D + } + + .. graphviz:: + + digraph { + A -> B + A -> C + B -> C + C -> B + B -> D + } + + where vertices are the letters A, B, C and D. This graph can be encoded + using Python's elementary data structures as follows:: + + >>> V = ['A', 'B', 'C', 'D'] + >>> E = [('A', 'B'), ('A', 'C'), ('B', 'C'), ('C', 'B'), ('B', 'D')] + + The strongly connected components of this graph can be computed as + + >>> from sympy.utilities.iterables import strongly_connected_components + + >>> strongly_connected_components((V, E)) + [['D'], ['B', 'C'], ['A']] + + This also gives the components in reverse topological order. + + Since the subgraph containing B and C has a cycle they must be together in + a strongly connected component. A and D are connected to the rest of the + graph but not in a cyclic manner so they appear as their own strongly + connected components. + + + Notes + ===== + + The vertices of the graph must be hashable for the data structures used. + If the vertices are unhashable replace them with integer indices. + + This function uses Tarjan's algorithm to compute the strongly connected + components in `O(|V|+|E|)` (linear) time. + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Strongly_connected_component + .. [2] https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm + + + See Also + ======== + + sympy.utilities.iterables.connected_components + + """ + # Map from a vertex to its neighbours + V, E = G + Gmap = {vi: [] for vi in V} + for v1, v2 in E: + Gmap[v1].append(v2) + return _strongly_connected_components(V, Gmap) + + +def _strongly_connected_components(V, Gmap): + """More efficient internal routine for strongly_connected_components""" + # + # Here V is an iterable of vertices and Gmap is a dict mapping each vertex + # to a list of neighbours e.g.: + # + # V = [0, 1, 2, 3] + # Gmap = {0: [2, 3], 1: [0]} + # + # For a large graph these data structures can often be created more + # efficiently then those expected by strongly_connected_components() which + # in this case would be + # + # V = [0, 1, 2, 3] + # Gmap = [(0, 2), (0, 3), (1, 0)] + # + # XXX: Maybe this should be the recommended function to use instead... + # + + # Non-recursive Tarjan's algorithm: + lowlink = {} + indices = {} + stack = OrderedDict() + callstack = [] + components = [] + nomore = object() + + def start(v): + index = len(stack) + indices[v] = lowlink[v] = index + stack[v] = None + callstack.append((v, iter(Gmap[v]))) + + def finish(v1): + # Finished a component? + if lowlink[v1] == indices[v1]: + component = [stack.popitem()[0]] + while component[-1] is not v1: + component.append(stack.popitem()[0]) + components.append(component[::-1]) + v2, _ = callstack.pop() + if callstack: + v1, _ = callstack[-1] + lowlink[v1] = min(lowlink[v1], lowlink[v2]) + + for v in V: + if v in indices: + continue + start(v) + while callstack: + v1, it1 = callstack[-1] + v2 = next(it1, nomore) + # Finished children of v1? + if v2 is nomore: + finish(v1) + # Recurse on v2 + elif v2 not in indices: + start(v2) + elif v2 in stack: + lowlink[v1] = min(lowlink[v1], indices[v2]) + + # Reverse topological sort order: + return components + + +def connected_components(G): + r""" + Connected components of an undirected graph or weakly connected components + of a directed graph. + + + Parameters + ========== + + G : tuple[list, list[tuple[T, T]] + A tuple consisting of a list of vertices and a list of edges of + a graph whose connected components are to be found. + + + Examples + ======== + + + Given an undirected graph:: + + graph { + A -- B + C -- D + } + + .. graphviz:: + + graph { + A -- B + C -- D + } + + We can find the connected components using this function if we include + each edge in both directions:: + + >>> from sympy.utilities.iterables import connected_components + + >>> V = ['A', 'B', 'C', 'D'] + >>> E = [('A', 'B'), ('B', 'A'), ('C', 'D'), ('D', 'C')] + >>> connected_components((V, E)) + [['A', 'B'], ['C', 'D']] + + The weakly connected components of a directed graph can found the same + way. + + + Notes + ===== + + The vertices of the graph must be hashable for the data structures used. + If the vertices are unhashable replace them with integer indices. + + This function uses Tarjan's algorithm to compute the connected components + in `O(|V|+|E|)` (linear) time. + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Component_%28graph_theory%29 + .. [2] https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm + + + See Also + ======== + + sympy.utilities.iterables.strongly_connected_components + + """ + # Duplicate edges both ways so that the graph is effectively undirected + # and return the strongly connected components: + V, E = G + E_undirected = [] + for v1, v2 in E: + E_undirected.extend([(v1, v2), (v2, v1)]) + return strongly_connected_components((V, E_undirected)) + + +def rotate_left(x, y): + """ + Left rotates a list x by the number of steps specified + in y. + + Examples + ======== + + >>> from sympy.utilities.iterables import rotate_left + >>> a = [0, 1, 2] + >>> rotate_left(a, 1) + [1, 2, 0] + """ + if len(x) == 0: + return [] + y = y % len(x) + return x[y:] + x[:y] + + +def rotate_right(x, y): + """ + Right rotates a list x by the number of steps specified + in y. + + Examples + ======== + + >>> from sympy.utilities.iterables import rotate_right + >>> a = [0, 1, 2] + >>> rotate_right(a, 1) + [2, 0, 1] + """ + if len(x) == 0: + return [] + y = len(x) - y % len(x) + return x[y:] + x[:y] + + +def least_rotation(x, key=None): + ''' + Returns the number of steps of left rotation required to + obtain lexicographically minimal string/list/tuple, etc. + + Examples + ======== + + >>> from sympy.utilities.iterables import least_rotation, rotate_left + >>> a = [3, 1, 5, 1, 2] + >>> least_rotation(a) + 3 + >>> rotate_left(a, _) + [1, 2, 3, 1, 5] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Lexicographically_minimal_string_rotation + + ''' + from sympy.functions.elementary.miscellaneous import Id + if key is None: key = Id + S = x + x # Concatenate string to it self to avoid modular arithmetic + f = [-1] * len(S) # Failure function + k = 0 # Least rotation of string found so far + for j in range(1,len(S)): + sj = S[j] + i = f[j-k-1] + while i != -1 and sj != S[k+i+1]: + if key(sj) < key(S[k+i+1]): + k = j-i-1 + i = f[i] + if sj != S[k+i+1]: + if key(sj) < key(S[k]): + k = j + f[j-k] = -1 + else: + f[j-k] = i+1 + return k + + +def multiset_combinations(m, n, g=None): + """ + Return the unique combinations of size ``n`` from multiset ``m``. + + Examples + ======== + + >>> from sympy.utilities.iterables import multiset_combinations + >>> from itertools import combinations + >>> [''.join(i) for i in multiset_combinations('baby', 3)] + ['abb', 'aby', 'bby'] + + >>> def count(f, s): return len(list(f(s, 3))) + + The number of combinations depends on the number of letters; the + number of unique combinations depends on how the letters are + repeated. + + >>> s1 = 'abracadabra' + >>> s2 = 'banana tree' + >>> count(combinations, s1), count(multiset_combinations, s1) + (165, 23) + >>> count(combinations, s2), count(multiset_combinations, s2) + (165, 54) + + """ + from sympy.core.sorting import ordered + if g is None: + if isinstance(m, dict): + if any(as_int(v) < 0 for v in m.values()): + raise ValueError('counts cannot be negative') + N = sum(m.values()) + if n > N: + return + g = [[k, m[k]] for k in ordered(m)] + else: + m = list(m) + N = len(m) + if n > N: + return + try: + m = multiset(m) + g = [(k, m[k]) for k in ordered(m)] + except TypeError: + m = list(ordered(m)) + g = [list(i) for i in group(m, multiple=False)] + del m + else: + # not checking counts since g is intended for internal use + N = sum(v for k, v in g) + if n > N or not n: + yield [] + else: + for i, (k, v) in enumerate(g): + if v >= n: + yield [k]*n + v = n - 1 + for v in range(min(n, v), 0, -1): + for j in multiset_combinations(None, n - v, g[i + 1:]): + rv = [k]*v + j + if len(rv) == n: + yield rv + +def multiset_permutations(m, size=None, g=None): + """ + Return the unique permutations of multiset ``m``. + + Examples + ======== + + >>> from sympy.utilities.iterables import multiset_permutations + >>> from sympy import factorial + >>> [''.join(i) for i in multiset_permutations('aab')] + ['aab', 'aba', 'baa'] + >>> factorial(len('banana')) + 720 + >>> len(list(multiset_permutations('banana'))) + 60 + """ + from sympy.core.sorting import ordered + if g is None: + if isinstance(m, dict): + if any(as_int(v) < 0 for v in m.values()): + raise ValueError('counts cannot be negative') + g = [[k, m[k]] for k in ordered(m)] + else: + m = list(ordered(m)) + g = [list(i) for i in group(m, multiple=False)] + del m + do = [gi for gi in g if gi[1] > 0] + SUM = sum(gi[1] for gi in do) + if not do or size is not None and (size > SUM or size < 1): + if not do and size is None or size == 0: + yield [] + return + elif size == 1: + for k, v in do: + yield [k] + elif len(do) == 1: + k, v = do[0] + v = v if size is None else (size if size <= v else 0) + yield [k for i in range(v)] + elif all(v == 1 for k, v in do): + for p in permutations([k for k, v in do], size): + yield list(p) + else: + size = size if size is not None else SUM + for i, (k, v) in enumerate(do): + do[i][1] -= 1 + for j in multiset_permutations(None, size - 1, do): + if j: + yield [k] + j + do[i][1] += 1 + + +def _partition(seq, vector, m=None): + """ + Return the partition of seq as specified by the partition vector. + + Examples + ======== + + >>> from sympy.utilities.iterables import _partition + >>> _partition('abcde', [1, 0, 1, 2, 0]) + [['b', 'e'], ['a', 'c'], ['d']] + + Specifying the number of bins in the partition is optional: + + >>> _partition('abcde', [1, 0, 1, 2, 0], 3) + [['b', 'e'], ['a', 'c'], ['d']] + + The output of _set_partitions can be passed as follows: + + >>> output = (3, [1, 0, 1, 2, 0]) + >>> _partition('abcde', *output) + [['b', 'e'], ['a', 'c'], ['d']] + + See Also + ======== + + combinatorics.partitions.Partition.from_rgs + + """ + if m is None: + m = max(vector) + 1 + elif isinstance(vector, int): # entered as m, vector + vector, m = m, vector + p = [[] for i in range(m)] + for i, v in enumerate(vector): + p[v].append(seq[i]) + return p + + +def _set_partitions(n): + """Cycle through all partitions of n elements, yielding the + current number of partitions, ``m``, and a mutable list, ``q`` + such that ``element[i]`` is in part ``q[i]`` of the partition. + + NOTE: ``q`` is modified in place and generally should not be changed + between function calls. + + Examples + ======== + + >>> from sympy.utilities.iterables import _set_partitions, _partition + >>> for m, q in _set_partitions(3): + ... print('%s %s %s' % (m, q, _partition('abc', q, m))) + 1 [0, 0, 0] [['a', 'b', 'c']] + 2 [0, 0, 1] [['a', 'b'], ['c']] + 2 [0, 1, 0] [['a', 'c'], ['b']] + 2 [0, 1, 1] [['a'], ['b', 'c']] + 3 [0, 1, 2] [['a'], ['b'], ['c']] + + Notes + ===== + + This algorithm is similar to, and solves the same problem as, + Algorithm 7.2.1.5H, from volume 4A of Knuth's The Art of Computer + Programming. Knuth uses the term "restricted growth string" where + this code refers to a "partition vector". In each case, the meaning is + the same: the value in the ith element of the vector specifies to + which part the ith set element is to be assigned. + + At the lowest level, this code implements an n-digit big-endian + counter (stored in the array q) which is incremented (with carries) to + get the next partition in the sequence. A special twist is that a + digit is constrained to be at most one greater than the maximum of all + the digits to the left of it. The array p maintains this maximum, so + that the code can efficiently decide when a digit can be incremented + in place or whether it needs to be reset to 0 and trigger a carry to + the next digit. The enumeration starts with all the digits 0 (which + corresponds to all the set elements being assigned to the same 0th + part), and ends with 0123...n, which corresponds to each set element + being assigned to a different, singleton, part. + + This routine was rewritten to use 0-based lists while trying to + preserve the beauty and efficiency of the original algorithm. + + References + ========== + + .. [1] Nijenhuis, Albert and Wilf, Herbert. (1978) Combinatorial Algorithms, + 2nd Ed, p 91, algorithm "nexequ". Available online from + https://www.math.upenn.edu/~wilf/website/CombAlgDownld.html (viewed + November 17, 2012). + + """ + p = [0]*n + q = [0]*n + nc = 1 + yield nc, q + while nc != n: + m = n + while 1: + m -= 1 + i = q[m] + if p[i] != 1: + break + q[m] = 0 + i += 1 + q[m] = i + m += 1 + nc += m - n + p[0] += n - m + if i == nc: + p[nc] = 0 + nc += 1 + p[i - 1] -= 1 + p[i] += 1 + yield nc, q + + +def multiset_partitions(multiset, m=None): + """ + Return unique partitions of the given multiset (in list form). + If ``m`` is None, all multisets will be returned, otherwise only + partitions with ``m`` parts will be returned. + + If ``multiset`` is an integer, a range [0, 1, ..., multiset - 1] + will be supplied. + + Examples + ======== + + >>> from sympy.utilities.iterables import multiset_partitions + >>> list(multiset_partitions([1, 2, 3, 4], 2)) + [[[1, 2, 3], [4]], [[1, 2, 4], [3]], [[1, 2], [3, 4]], + [[1, 3, 4], [2]], [[1, 3], [2, 4]], [[1, 4], [2, 3]], + [[1], [2, 3, 4]]] + >>> list(multiset_partitions([1, 2, 3, 4], 1)) + [[[1, 2, 3, 4]]] + + Only unique partitions are returned and these will be returned in a + canonical order regardless of the order of the input: + + >>> a = [1, 2, 2, 1] + >>> ans = list(multiset_partitions(a, 2)) + >>> a.sort() + >>> list(multiset_partitions(a, 2)) == ans + True + >>> a = range(3, 1, -1) + >>> (list(multiset_partitions(a)) == + ... list(multiset_partitions(sorted(a)))) + True + + If m is omitted then all partitions will be returned: + + >>> list(multiset_partitions([1, 1, 2])) + [[[1, 1, 2]], [[1, 1], [2]], [[1, 2], [1]], [[1], [1], [2]]] + >>> list(multiset_partitions([1]*3)) + [[[1, 1, 1]], [[1], [1, 1]], [[1], [1], [1]]] + + Counting + ======== + + The number of partitions of a set is given by the bell number: + + >>> from sympy import bell + >>> len(list(multiset_partitions(5))) == bell(5) == 52 + True + + The number of partitions of length k from a set of size n is given by the + Stirling Number of the 2nd kind: + + >>> from sympy.functions.combinatorial.numbers import stirling + >>> stirling(5, 2) == len(list(multiset_partitions(5, 2))) == 15 + True + + These comments on counting apply to *sets*, not multisets. + + Notes + ===== + + When all the elements are the same in the multiset, the order + of the returned partitions is determined by the ``partitions`` + routine. If one is counting partitions then it is better to use + the ``nT`` function. + + See Also + ======== + + partitions + sympy.combinatorics.partitions.Partition + sympy.combinatorics.partitions.IntegerPartition + sympy.functions.combinatorial.numbers.nT + + """ + # This function looks at the supplied input and dispatches to + # several special-case routines as they apply. + if isinstance(multiset, int): + n = multiset + if m and m > n: + return + multiset = list(range(n)) + if m == 1: + yield [multiset[:]] + return + + # If m is not None, it can sometimes be faster to use + # MultisetPartitionTraverser.enum_range() even for inputs + # which are sets. Since the _set_partitions code is quite + # fast, this is only advantageous when the overall set + # partitions outnumber those with the desired number of parts + # by a large factor. (At least 60.) Such a switch is not + # currently implemented. + for nc, q in _set_partitions(n): + if m is None or nc == m: + rv = [[] for i in range(nc)] + for i in range(n): + rv[q[i]].append(multiset[i]) + yield rv + return + + if len(multiset) == 1 and isinstance(multiset, str): + multiset = [multiset] + + if not has_variety(multiset): + # Only one component, repeated n times. The resulting + # partitions correspond to partitions of integer n. + n = len(multiset) + if m and m > n: + return + if m == 1: + yield [multiset[:]] + return + x = multiset[:1] + for size, p in partitions(n, m, size=True): + if m is None or size == m: + rv = [] + for k in sorted(p): + rv.extend([x*k]*p[k]) + yield rv + else: + from sympy.core.sorting import ordered + multiset = list(ordered(multiset)) + n = len(multiset) + if m and m > n: + return + if m == 1: + yield [multiset[:]] + return + + # Split the information of the multiset into two lists - + # one of the elements themselves, and one (of the same length) + # giving the number of repeats for the corresponding element. + elements, multiplicities = zip(*group(multiset, False)) + + if len(elements) < len(multiset): + # General case - multiset with more than one distinct element + # and at least one element repeated more than once. + if m: + mpt = MultisetPartitionTraverser() + for state in mpt.enum_range(multiplicities, m-1, m): + yield list_visitor(state, elements) + else: + for state in multiset_partitions_taocp(multiplicities): + yield list_visitor(state, elements) + else: + # Set partitions case - no repeated elements. Pretty much + # same as int argument case above, with same possible, but + # currently unimplemented optimization for some cases when + # m is not None + for nc, q in _set_partitions(n): + if m is None or nc == m: + rv = [[] for i in range(nc)] + for i in range(n): + rv[q[i]].append(i) + yield [[multiset[j] for j in i] for i in rv] + + +def partitions(n, m=None, k=None, size=False): + """Generate all partitions of positive integer, n. + + Each partition is represented as a dictionary, mapping an integer + to the number of copies of that integer in the partition. For example, + the first partition of 4 returned is {4: 1}, "4: one of them". + + Parameters + ========== + n : int + m : int, optional + limits number of parts in partition (mnemonic: m, maximum parts) + k : int, optional + limits the numbers that are kept in the partition (mnemonic: k, keys) + size : bool, default: False + If ``True``, (M, P) is returned where M is the sum of the + multiplicities and P is the generated partition. + If ``False``, only the generated partition is returned. + + Examples + ======== + + >>> from sympy.utilities.iterables import partitions + + The numbers appearing in the partition (the key of the returned dict) + are limited with k: + + >>> for p in partitions(6, k=2): # doctest: +SKIP + ... print(p) + {2: 3} + {1: 2, 2: 2} + {1: 4, 2: 1} + {1: 6} + + The maximum number of parts in the partition (the sum of the values in + the returned dict) are limited with m (default value, None, gives + partitions from 1 through n): + + >>> for p in partitions(6, m=2): # doctest: +SKIP + ... print(p) + ... + {6: 1} + {1: 1, 5: 1} + {2: 1, 4: 1} + {3: 2} + + References + ========== + + .. [1] modified from Tim Peter's version to allow for k and m values: + https://code.activestate.com/recipes/218332-generator-for-integer-partitions/ + + See Also + ======== + + sympy.combinatorics.partitions.Partition + sympy.combinatorics.partitions.IntegerPartition + + """ + if (n <= 0 or + m is not None and m < 1 or + k is not None and k < 1 or + m and k and m*k < n): + # the empty set is the only way to handle these inputs + # and returning {} to represent it is consistent with + # the counting convention, e.g. nT(0) == 1. + if size: + yield 0, {} + else: + yield {} + return + + if m is None: + m = n + else: + m = min(m, n) + k = min(k or n, n) + + n, m, k = as_int(n), as_int(m), as_int(k) + q, r = divmod(n, k) + ms = {k: q} + keys = [k] # ms.keys(), from largest to smallest + if r: + ms[r] = 1 + keys.append(r) + room = m - q - bool(r) + if size: + yield sum(ms.values()), ms.copy() + else: + yield ms.copy() + + while keys != [1]: + # Reuse any 1's. + if keys[-1] == 1: + del keys[-1] + reuse = ms.pop(1) + room += reuse + else: + reuse = 0 + + while 1: + # Let i be the smallest key larger than 1. Reuse one + # instance of i. + i = keys[-1] + newcount = ms[i] = ms[i] - 1 + reuse += i + if newcount == 0: + del keys[-1], ms[i] + room += 1 + + # Break the remainder into pieces of size i-1. + i -= 1 + q, r = divmod(reuse, i) + need = q + bool(r) + if need > room: + if not keys: + return + continue + + ms[i] = q + keys.append(i) + if r: + ms[r] = 1 + keys.append(r) + break + room -= need + if size: + yield sum(ms.values()), ms.copy() + else: + yield ms.copy() + + +def ordered_partitions(n, m=None, sort=True): + """Generates ordered partitions of integer *n*. + + Parameters + ========== + n : int + m : int, optional + The default value gives partitions of all sizes else only + those with size m. In addition, if *m* is not None then + partitions are generated *in place* (see examples). + sort : bool, default: True + Controls whether partitions are + returned in sorted order when *m* is not None; when False, + the partitions are returned as fast as possible with elements + sorted, but when m|n the partitions will not be in + ascending lexicographical order. + + Examples + ======== + + >>> from sympy.utilities.iterables import ordered_partitions + + All partitions of 5 in ascending lexicographical: + + >>> for p in ordered_partitions(5): + ... print(p) + [1, 1, 1, 1, 1] + [1, 1, 1, 2] + [1, 1, 3] + [1, 2, 2] + [1, 4] + [2, 3] + [5] + + Only partitions of 5 with two parts: + + >>> for p in ordered_partitions(5, 2): + ... print(p) + [1, 4] + [2, 3] + + When ``m`` is given, a given list objects will be used more than + once for speed reasons so you will not see the correct partitions + unless you make a copy of each as it is generated: + + >>> [p for p in ordered_partitions(7, 3)] + [[1, 1, 1], [1, 1, 1], [1, 1, 1], [2, 2, 2]] + >>> [list(p) for p in ordered_partitions(7, 3)] + [[1, 1, 5], [1, 2, 4], [1, 3, 3], [2, 2, 3]] + + When ``n`` is a multiple of ``m``, the elements are still sorted + but the partitions themselves will be *unordered* if sort is False; + the default is to return them in ascending lexicographical order. + + >>> for p in ordered_partitions(6, 2): + ... print(p) + [1, 5] + [2, 4] + [3, 3] + + But if speed is more important than ordering, sort can be set to + False: + + >>> for p in ordered_partitions(6, 2, sort=False): + ... print(p) + [1, 5] + [3, 3] + [2, 4] + + References + ========== + + .. [1] Generating Integer Partitions, [online], + Available: https://jeromekelleher.net/generating-integer-partitions.html + .. [2] Jerome Kelleher and Barry O'Sullivan, "Generating All + Partitions: A Comparison Of Two Encodings", [online], + Available: https://arxiv.org/pdf/0909.2331v2.pdf + """ + if n < 1 or m is not None and m < 1: + # the empty set is the only way to handle these inputs + # and returning {} to represent it is consistent with + # the counting convention, e.g. nT(0) == 1. + yield [] + return + + if m is None: + # The list `a`'s leading elements contain the partition in which + # y is the biggest element and x is either the same as y or the + # 2nd largest element; v and w are adjacent element indices + # to which x and y are being assigned, respectively. + a = [1]*n + y = -1 + v = n + while v > 0: + v -= 1 + x = a[v] + 1 + while y >= 2 * x: + a[v] = x + y -= x + v += 1 + w = v + 1 + while x <= y: + a[v] = x + a[w] = y + yield a[:w + 1] + x += 1 + y -= 1 + a[v] = x + y + y = a[v] - 1 + yield a[:w] + elif m == 1: + yield [n] + elif n == m: + yield [1]*n + else: + # recursively generate partitions of size m + for b in range(1, n//m + 1): + a = [b]*m + x = n - b*m + if not x: + if sort: + yield a + elif not sort and x <= m: + for ax in ordered_partitions(x, sort=False): + mi = len(ax) + a[-mi:] = [i + b for i in ax] + yield a + a[-mi:] = [b]*mi + else: + for mi in range(1, m): + for ax in ordered_partitions(x, mi, sort=True): + a[-mi:] = [i + b for i in ax] + yield a + a[-mi:] = [b]*mi + + +def binary_partitions(n): + """ + Generates the binary partition of *n*. + + A binary partition consists only of numbers that are + powers of two. Each step reduces a `2^{k+1}` to `2^k` and + `2^k`. Thus 16 is converted to 8 and 8. + + Examples + ======== + + >>> from sympy.utilities.iterables import binary_partitions + >>> for i in binary_partitions(5): + ... print(i) + ... + [4, 1] + [2, 2, 1] + [2, 1, 1, 1] + [1, 1, 1, 1, 1] + + References + ========== + + .. [1] TAOCP 4, section 7.2.1.5, problem 64 + + """ + from math import ceil, log2 + power = int(2**(ceil(log2(n)))) + acc = 0 + partition = [] + while power: + if acc + power <= n: + partition.append(power) + acc += power + power >>= 1 + + last_num = len(partition) - 1 - (n & 1) + while last_num >= 0: + yield partition + if partition[last_num] == 2: + partition[last_num] = 1 + partition.append(1) + last_num -= 1 + continue + partition.append(1) + partition[last_num] >>= 1 + x = partition[last_num + 1] = partition[last_num] + last_num += 1 + while x > 1: + if x <= len(partition) - last_num - 1: + del partition[-x + 1:] + last_num += 1 + partition[last_num] = x + else: + x >>= 1 + yield [1]*n + + +def has_dups(seq): + """Return True if there are any duplicate elements in ``seq``. + + Examples + ======== + + >>> from sympy import has_dups, Dict, Set + >>> has_dups((1, 2, 1)) + True + >>> has_dups(range(3)) + False + >>> all(has_dups(c) is False for c in (set(), Set(), dict(), Dict())) + True + """ + from sympy.core.containers import Dict + from sympy.sets.sets import Set + if isinstance(seq, (dict, set, Dict, Set)): + return False + unique = set() + try: + return any(True for s in seq if s in unique or unique.add(s)) + except TypeError: + return len(seq) != len(list(uniq(seq))) + + +def has_variety(seq): + """Return True if there are any different elements in ``seq``. + + Examples + ======== + + >>> from sympy import has_variety + + >>> has_variety((1, 2, 1)) + True + >>> has_variety((1, 1, 1)) + False + """ + for i, s in enumerate(seq): + if i == 0: + sentinel = s + else: + if s != sentinel: + return True + return False + + +def uniq(seq, result=None): + """ + Yield unique elements from ``seq`` as an iterator. The second + parameter ``result`` is used internally; it is not necessary + to pass anything for this. + + Note: changing the sequence during iteration will raise a + RuntimeError if the size of the sequence is known; if you pass + an iterator and advance the iterator you will change the + output of this routine but there will be no warning. + + Examples + ======== + + >>> from sympy.utilities.iterables import uniq + >>> dat = [1, 4, 1, 5, 4, 2, 1, 2] + >>> type(uniq(dat)) in (list, tuple) + False + + >>> list(uniq(dat)) + [1, 4, 5, 2] + >>> list(uniq(x for x in dat)) + [1, 4, 5, 2] + >>> list(uniq([[1], [2, 1], [1]])) + [[1], [2, 1]] + """ + try: + n = len(seq) + except TypeError: + n = None + def check(): + # check that size of seq did not change during iteration; + # if n == None the object won't support size changing, e.g. + # an iterator can't be changed + if n is not None and len(seq) != n: + raise RuntimeError('sequence changed size during iteration') + try: + seen = set() + result = result or [] + for i, s in enumerate(seq): + if not (s in seen or seen.add(s)): + yield s + check() + except TypeError: + if s not in result: + yield s + check() + result.append(s) + if hasattr(seq, '__getitem__'): + yield from uniq(seq[i + 1:], result) + else: + yield from uniq(seq, result) + + +def generate_bell(n): + """Return permutations of [0, 1, ..., n - 1] such that each permutation + differs from the last by the exchange of a single pair of neighbors. + The ``n!`` permutations are returned as an iterator. In order to obtain + the next permutation from a random starting permutation, use the + ``next_trotterjohnson`` method of the Permutation class (which generates + the same sequence in a different manner). + + Examples + ======== + + >>> from itertools import permutations + >>> from sympy.utilities.iterables import generate_bell + >>> from sympy import zeros, Matrix + + This is the sort of permutation used in the ringing of physical bells, + and does not produce permutations in lexicographical order. Rather, the + permutations differ from each other by exactly one inversion, and the + position at which the swapping occurs varies periodically in a simple + fashion. Consider the first few permutations of 4 elements generated + by ``permutations`` and ``generate_bell``: + + >>> list(permutations(range(4)))[:5] + [(0, 1, 2, 3), (0, 1, 3, 2), (0, 2, 1, 3), (0, 2, 3, 1), (0, 3, 1, 2)] + >>> list(generate_bell(4))[:5] + [(0, 1, 2, 3), (0, 1, 3, 2), (0, 3, 1, 2), (3, 0, 1, 2), (3, 0, 2, 1)] + + Notice how the 2nd and 3rd lexicographical permutations have 3 elements + out of place whereas each "bell" permutation always has only two + elements out of place relative to the previous permutation (and so the + signature (+/-1) of a permutation is opposite of the signature of the + previous permutation). + + How the position of inversion varies across the elements can be seen + by tracing out where the largest number appears in the permutations: + + >>> m = zeros(4, 24) + >>> for i, p in enumerate(generate_bell(4)): + ... m[:, i] = Matrix([j - 3 for j in list(p)]) # make largest zero + >>> m.print_nonzero('X') + [XXX XXXXXX XXXXXX XXX] + [XX XX XXXX XX XXXX XX XX] + [X XXXX XX XXXX XX XXXX X] + [ XXXXXX XXXXXX XXXXXX ] + + See Also + ======== + + sympy.combinatorics.permutations.Permutation.next_trotterjohnson + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Method_ringing + + .. [2] https://stackoverflow.com/questions/4856615/recursive-permutation/4857018 + + .. [3] https://web.archive.org/web/20160313023044/http://programminggeeks.com/bell-algorithm-for-permutation/ + + .. [4] https://en.wikipedia.org/wiki/Steinhaus%E2%80%93Johnson%E2%80%93Trotter_algorithm + + .. [5] Generating involutions, derangements, and relatives by ECO + Vincent Vajnovszki, DMTCS vol 1 issue 12, 2010 + + """ + n = as_int(n) + if n < 1: + raise ValueError('n must be a positive integer') + if n == 1: + yield (0,) + elif n == 2: + yield (0, 1) + yield (1, 0) + elif n == 3: + yield from [(0, 1, 2), (0, 2, 1), (2, 0, 1), (2, 1, 0), (1, 2, 0), (1, 0, 2)] + else: + m = n - 1 + op = [0] + [-1]*m + l = list(range(n)) + while True: + yield tuple(l) + # find biggest element with op + big = None, -1 # idx, value + for i in range(n): + if op[i] and l[i] > big[1]: + big = i, l[i] + i, _ = big + if i is None: + break # there are no ops left + # swap it with neighbor in the indicated direction + j = i + op[i] + l[i], l[j] = l[j], l[i] + op[i], op[j] = op[j], op[i] + # if it landed at the end or if the neighbor in the same + # direction is bigger then turn off op + if j == 0 or j == m or l[j + op[j]] > l[j]: + op[j] = 0 + # any element bigger to the left gets +1 op + for i in range(j): + if l[i] > l[j]: + op[i] = 1 + # any element bigger to the right gets -1 op + for i in range(j + 1, n): + if l[i] > l[j]: + op[i] = -1 + + +def generate_involutions(n): + """ + Generates involutions. + + An involution is a permutation that when multiplied + by itself equals the identity permutation. In this + implementation the involutions are generated using + Fixed Points. + + Alternatively, an involution can be considered as + a permutation that does not contain any cycles with + a length that is greater than two. + + Examples + ======== + + >>> from sympy.utilities.iterables import generate_involutions + >>> list(generate_involutions(3)) + [(0, 1, 2), (0, 2, 1), (1, 0, 2), (2, 1, 0)] + >>> len(list(generate_involutions(4))) + 10 + + References + ========== + + .. [1] https://mathworld.wolfram.com/PermutationInvolution.html + + """ + idx = list(range(n)) + for p in permutations(idx): + for i in idx: + if p[p[i]] != i: + break + else: + yield p + + +def multiset_derangements(s): + """Generate derangements of the elements of s *in place*. + + Examples + ======== + + >>> from sympy.utilities.iterables import multiset_derangements, uniq + + Because the derangements of multisets (not sets) are generated + in place, copies of the return value must be made if a collection + of derangements is desired or else all values will be the same: + + >>> list(uniq([i for i in multiset_derangements('1233')])) + [[None, None, None, None]] + >>> [i.copy() for i in multiset_derangements('1233')] + [['3', '3', '1', '2'], ['3', '3', '2', '1']] + >>> [''.join(i) for i in multiset_derangements('1233')] + ['3312', '3321'] + """ + from sympy.core.sorting import ordered + # create multiset dictionary of hashable elements or else + # remap elements to integers + try: + ms = multiset(s) + except TypeError: + # give each element a canonical integer value + key = dict(enumerate(ordered(uniq(s)))) + h = [] + for si in s: + for k in key: + if key[k] == si: + h.append(k) + break + for i in multiset_derangements(h): + yield [key[j] for j in i] + return + + mx = max(ms.values()) # max repetition of any element + n = len(s) # the number of elements + + ## special cases + + # 1) one element has more than half the total cardinality of s: no + # derangements are possible. + if mx*2 > n: + return + + # 2) all elements appear once: singletons + if len(ms) == n: + yield from _set_derangements(s) + return + + # find the first element that is repeated the most to place + # in the following two special cases where the selection + # is unambiguous: either there are two elements with multiplicity + # of mx or else there is only one with multiplicity mx + for M in ms: + if ms[M] == mx: + break + + inonM = [i for i in range(n) if s[i] != M] # location of non-M + iM = [i for i in range(n) if s[i] == M] # locations of M + rv = [None]*n + + # 3) half are the same + if 2*mx == n: + # M goes into non-M locations + for i in inonM: + rv[i] = M + # permutations of non-M go to M locations + for p in multiset_permutations([s[i] for i in inonM]): + for i, pi in zip(iM, p): + rv[i] = pi + yield rv + # clean-up (and encourages proper use of routine) + rv[:] = [None]*n + return + + # 4) single repeat covers all but 1 of the non-repeats: + # if there is one repeat then the multiset of the values + # of ms would be {mx: 1, 1: n - mx}, i.e. there would + # be n - mx + 1 values with the condition that n - 2*mx = 1 + if n - 2*mx == 1 and len(ms.values()) == n - mx + 1: + for i, i1 in enumerate(inonM): + ifill = inonM[:i] + inonM[i+1:] + for j in ifill: + rv[j] = M + for p in permutations([s[j] for j in ifill]): + rv[i1] = s[i1] + for j, pi in zip(iM, p): + rv[j] = pi + k = i1 + for j in iM: + rv[j], rv[k] = rv[k], rv[j] + yield rv + k = j + # clean-up (and encourages proper use of routine) + rv[:] = [None]*n + return + + ## general case is handled with 3 helpers: + # 1) `finish_derangements` will place the last two elements + # which have arbitrary multiplicities, e.g. for multiset + # {c: 3, a: 2, b: 2}, the last two elements are a and b + # 2) `iopen` will tell where a given element can be placed + # 3) `do` will recursively place elements into subsets of + # valid locations + + def finish_derangements(): + """Place the last two elements into the partially completed + derangement, and yield the results. + """ + + a = take[1][0] # penultimate element + a_ct = take[1][1] + b = take[0][0] # last element to be placed + b_ct = take[0][1] + + # split the indexes of the not-already-assigned elements of rv into + # three categories + forced_a = [] # positions which must have an a + forced_b = [] # positions which must have a b + open_free = [] # positions which could take either + for i in range(len(s)): + if rv[i] is None: + if s[i] == a: + forced_b.append(i) + elif s[i] == b: + forced_a.append(i) + else: + open_free.append(i) + + if len(forced_a) > a_ct or len(forced_b) > b_ct: + # No derangement possible + return + + for i in forced_a: + rv[i] = a + for i in forced_b: + rv[i] = b + for a_place in combinations(open_free, a_ct - len(forced_a)): + for a_pos in a_place: + rv[a_pos] = a + for i in open_free: + if rv[i] is None: # anything not in the subset is set to b + rv[i] = b + yield rv + # Clean up/undo the final placements + for i in open_free: + rv[i] = None + + # additional cleanup - clear forced_a, forced_b + for i in forced_a: + rv[i] = None + for i in forced_b: + rv[i] = None + + def iopen(v): + # return indices at which element v can be placed in rv: + # locations which are not already occupied if that location + # does not already contain v in the same location of s + return [i for i in range(n) if rv[i] is None and s[i] != v] + + def do(j): + if j == 1: + # handle the last two elements (regardless of multiplicity) + # with a special method + yield from finish_derangements() + else: + # place the mx elements of M into a subset of places + # into which it can be replaced + M, mx = take[j] + for i in combinations(iopen(M), mx): + # place M + for ii in i: + rv[ii] = M + # recursively place the next element + yield from do(j - 1) + # mark positions where M was placed as once again + # open for placement of other elements + for ii in i: + rv[ii] = None + + # process elements in order of canonically decreasing multiplicity + take = sorted(ms.items(), key=lambda x:(x[1], x[0])) + yield from do(len(take) - 1) + rv[:] = [None]*n + + +def random_derangement(t, choice=None, strict=True): + """Return a list of elements in which none are in the same positions + as they were originally. If an element fills more than half of the positions + then an error will be raised since no derangement is possible. To obtain + a derangement of as many items as possible--with some of the most numerous + remaining in their original positions--pass `strict=False`. To produce a + pseudorandom derangment, pass a pseudorandom selector like `choice` (see + below). + + Examples + ======== + + >>> from sympy.utilities.iterables import random_derangement + >>> t = 'SymPy: a CAS in pure Python' + >>> d = random_derangement(t) + >>> all(i != j for i, j in zip(d, t)) + True + + A predictable result can be obtained by using a pseudorandom + generator for the choice: + + >>> from sympy.core.random import seed, choice as c + >>> seed(1) + >>> d = [''.join(random_derangement(t, c)) for i in range(5)] + >>> assert len(set(d)) != 1 # we got different values + + By reseeding, the same sequence can be obtained: + + >>> seed(1) + >>> d2 = [''.join(random_derangement(t, c)) for i in range(5)] + >>> assert d == d2 + """ + if choice is None: + import secrets + choice = secrets.choice + def shuffle(rv): + '''Knuth shuffle''' + for i in range(len(rv) - 1, 0, -1): + x = choice(rv[:i + 1]) + j = rv.index(x) + rv[i], rv[j] = rv[j], rv[i] + def pick(rv, n): + '''shuffle rv and return the first n values + ''' + shuffle(rv) + return rv[:n] + ms = multiset(t) + tot = len(t) + ms = sorted(ms.items(), key=lambda x: x[1]) + # if there are not enough spaces for the most + # plentiful element to move to then some of them + # will have to stay in place + M, mx = ms[-1] + n = len(t) + xs = 2*mx - tot + if xs > 0: + if strict: + raise ValueError('no derangement possible') + opts = [i for (i, c) in enumerate(t) if c == ms[-1][0]] + pick(opts, xs) + stay = sorted(opts[:xs]) + rv = list(t) + for i in reversed(stay): + rv.pop(i) + rv = random_derangement(rv, choice) + for i in stay: + rv.insert(i, ms[-1][0]) + return ''.join(rv) if type(t) is str else rv + # the normal derangement calculated from here + if n == len(ms): + # approx 1/3 will succeed + rv = list(t) + while True: + shuffle(rv) + if all(i != j for i,j in zip(rv, t)): + break + else: + # general case + rv = [None]*n + while True: + j = 0 + while j > -len(ms): # do most numerous first + j -= 1 + e, c = ms[j] + opts = [i for i in range(n) if rv[i] is None and t[i] != e] + if len(opts) < c: + for i in range(n): + rv[i] = None + break # try again + pick(opts, c) + for i in range(c): + rv[opts[i]] = e + else: + return rv + return rv + + +def _set_derangements(s): + """ + yield derangements of items in ``s`` which are assumed to contain + no repeated elements + """ + if len(s) < 2: + return + if len(s) == 2: + yield [s[1], s[0]] + return + if len(s) == 3: + yield [s[1], s[2], s[0]] + yield [s[2], s[0], s[1]] + return + for p in permutations(s): + if not any(i == j for i, j in zip(p, s)): + yield list(p) + + +def generate_derangements(s): + """ + Return unique derangements of the elements of iterable ``s``. + + Examples + ======== + + >>> from sympy.utilities.iterables import generate_derangements + >>> list(generate_derangements([0, 1, 2])) + [[1, 2, 0], [2, 0, 1]] + >>> list(generate_derangements([0, 1, 2, 2])) + [[2, 2, 0, 1], [2, 2, 1, 0]] + >>> list(generate_derangements([0, 1, 1])) + [] + + See Also + ======== + + sympy.functions.combinatorial.factorials.subfactorial + + """ + if not has_dups(s): + yield from _set_derangements(s) + else: + for p in multiset_derangements(s): + yield list(p) + + +def necklaces(n, k, free=False): + """ + A routine to generate necklaces that may (free=True) or may not + (free=False) be turned over to be viewed. The "necklaces" returned + are comprised of ``n`` integers (beads) with ``k`` different + values (colors). Only unique necklaces are returned. + + Examples + ======== + + >>> from sympy.utilities.iterables import necklaces, bracelets + >>> def show(s, i): + ... return ''.join(s[j] for j in i) + + The "unrestricted necklace" is sometimes also referred to as a + "bracelet" (an object that can be turned over, a sequence that can + be reversed) and the term "necklace" is used to imply a sequence + that cannot be reversed. So ACB == ABC for a bracelet (rotate and + reverse) while the two are different for a necklace since rotation + alone cannot make the two sequences the same. + + (mnemonic: Bracelets can be viewed Backwards, but Not Necklaces.) + + >>> B = [show('ABC', i) for i in bracelets(3, 3)] + >>> N = [show('ABC', i) for i in necklaces(3, 3)] + >>> set(N) - set(B) + {'ACB'} + + >>> list(necklaces(4, 2)) + [(0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 1, 1), + (0, 1, 0, 1), (0, 1, 1, 1), (1, 1, 1, 1)] + + >>> [show('.o', i) for i in bracelets(4, 2)] + ['....', '...o', '..oo', '.o.o', '.ooo', 'oooo'] + + References + ========== + + .. [1] https://mathworld.wolfram.com/Necklace.html + + .. [2] Frank Ruskey, Carla Savage, and Terry Min Yih Wang, + Generating necklaces, Journal of Algorithms 13 (1992), 414-430; + https://doi.org/10.1016/0196-6774(92)90047-G + + """ + # The FKM algorithm + if k == 0 and n > 0: + return + a = [0]*n + yield tuple(a) + if n == 0: + return + while True: + i = n - 1 + while a[i] == k - 1: + i -= 1 + if i == -1: + return + a[i] += 1 + for j in range(n - i - 1): + a[j + i + 1] = a[j] + if n % (i + 1) == 0 and (not free or all(a <= a[j::-1] + a[-1:j:-1] for j in range(n - 1))): + # No need to test j = n - 1. + yield tuple(a) + + +def bracelets(n, k): + """Wrapper to necklaces to return a free (unrestricted) necklace.""" + return necklaces(n, k, free=True) + + +def generate_oriented_forest(n): + """ + This algorithm generates oriented forests. + + An oriented graph is a directed graph having no symmetric pair of directed + edges. A forest is an acyclic graph, i.e., it has no cycles. A forest can + also be described as a disjoint union of trees, which are graphs in which + any two vertices are connected by exactly one simple path. + + Examples + ======== + + >>> from sympy.utilities.iterables import generate_oriented_forest + >>> list(generate_oriented_forest(4)) + [[0, 1, 2, 3], [0, 1, 2, 2], [0, 1, 2, 1], [0, 1, 2, 0], \ + [0, 1, 1, 1], [0, 1, 1, 0], [0, 1, 0, 1], [0, 1, 0, 0], [0, 0, 0, 0]] + + References + ========== + + .. [1] T. Beyer and S.M. Hedetniemi: constant time generation of + rooted trees, SIAM J. Computing Vol. 9, No. 4, November 1980 + + .. [2] https://stackoverflow.com/questions/1633833/oriented-forest-taocp-algorithm-in-python + + """ + P = list(range(-1, n)) + while True: + yield P[1:] + if P[n] > 0: + P[n] = P[P[n]] + else: + for p in range(n - 1, 0, -1): + if P[p] != 0: + target = P[p] - 1 + for q in range(p - 1, 0, -1): + if P[q] == target: + break + offset = p - q + for i in range(p, n + 1): + P[i] = P[i - offset] + break + else: + break + + +def minlex(seq, directed=True, key=None): + r""" + Return the rotation of the sequence in which the lexically smallest + elements appear first, e.g. `cba \rightarrow acb`. + + The sequence returned is a tuple, unless the input sequence is a string + in which case a string is returned. + + If ``directed`` is False then the smaller of the sequence and the + reversed sequence is returned, e.g. `cba \rightarrow abc`. + + If ``key`` is not None then it is used to extract a comparison key from each element in iterable. + + Examples + ======== + + >>> from sympy.combinatorics.polyhedron import minlex + >>> minlex((1, 2, 0)) + (0, 1, 2) + >>> minlex((1, 0, 2)) + (0, 2, 1) + >>> minlex((1, 0, 2), directed=False) + (0, 1, 2) + + >>> minlex('11010011000', directed=True) + '00011010011' + >>> minlex('11010011000', directed=False) + '00011001011' + + >>> minlex(('bb', 'aaa', 'c', 'a')) + ('a', 'bb', 'aaa', 'c') + >>> minlex(('bb', 'aaa', 'c', 'a'), key=len) + ('c', 'a', 'bb', 'aaa') + + """ + from sympy.functions.elementary.miscellaneous import Id + if key is None: key = Id + best = rotate_left(seq, least_rotation(seq, key=key)) + if not directed: + rseq = seq[::-1] + rbest = rotate_left(rseq, least_rotation(rseq, key=key)) + best = min(best, rbest, key=key) + + # Convert to tuple, unless we started with a string. + return tuple(best) if not isinstance(seq, str) else best + + +def runs(seq, op=gt): + """Group the sequence into lists in which successive elements + all compare the same with the comparison operator, ``op``: + op(seq[i + 1], seq[i]) is True from all elements in a run. + + Examples + ======== + + >>> from sympy.utilities.iterables import runs + >>> from operator import ge + >>> runs([0, 1, 2, 2, 1, 4, 3, 2, 2]) + [[0, 1, 2], [2], [1, 4], [3], [2], [2]] + >>> runs([0, 1, 2, 2, 1, 4, 3, 2, 2], op=ge) + [[0, 1, 2, 2], [1, 4], [3], [2, 2]] + """ + cycles = [] + seq = iter(seq) + try: + run = [next(seq)] + except StopIteration: + return [] + while True: + try: + ei = next(seq) + except StopIteration: + break + if op(ei, run[-1]): + run.append(ei) + continue + else: + cycles.append(run) + run = [ei] + if run: + cycles.append(run) + return cycles + + +def sequence_partitions(l, n, /): + r"""Returns the partition of sequence $l$ into $n$ bins + + Explanation + =========== + + Given the sequence $l_1 \cdots l_m \in V^+$ where + $V^+$ is the Kleene plus of $V$ + + The set of $n$ partitions of $l$ is defined as: + + .. math:: + \{(s_1, \cdots, s_n) | s_1 \in V^+, \cdots, s_n \in V^+, + s_1 \cdots s_n = l_1 \cdots l_m\} + + Parameters + ========== + + l : Sequence[T] + A nonempty sequence of any Python objects + + n : int + A positive integer + + Yields + ====== + + out : list[Sequence[T]] + A list of sequences with concatenation equals $l$. + This should conform with the type of $l$. + + Examples + ======== + + >>> from sympy.utilities.iterables import sequence_partitions + >>> for out in sequence_partitions([1, 2, 3, 4], 2): + ... print(out) + [[1], [2, 3, 4]] + [[1, 2], [3, 4]] + [[1, 2, 3], [4]] + + Notes + ===== + + This is modified version of EnricoGiampieri's partition generator + from https://stackoverflow.com/questions/13131491/partition-n-items-into-k-bins-in-python-lazily + + See Also + ======== + + sequence_partitions_empty + """ + # Asserting l is nonempty is done only for sanity check + if n == 1 and l: + yield [l] + return + for i in range(1, len(l)): + for part in sequence_partitions(l[i:], n - 1): + yield [l[:i]] + part + + +def sequence_partitions_empty(l, n, /): + r"""Returns the partition of sequence $l$ into $n$ bins with + empty sequence + + Explanation + =========== + + Given the sequence $l_1 \cdots l_m \in V^*$ where + $V^*$ is the Kleene star of $V$ + + The set of $n$ partitions of $l$ is defined as: + + .. math:: + \{(s_1, \cdots, s_n) | s_1 \in V^*, \cdots, s_n \in V^*, + s_1 \cdots s_n = l_1 \cdots l_m\} + + There are more combinations than :func:`sequence_partitions` because + empty sequence can fill everywhere, so we try to provide different + utility for this. + + Parameters + ========== + + l : Sequence[T] + A sequence of any Python objects (can be possibly empty) + + n : int + A positive integer + + Yields + ====== + + out : list[Sequence[T]] + A list of sequences with concatenation equals $l$. + This should conform with the type of $l$. + + Examples + ======== + + >>> from sympy.utilities.iterables import sequence_partitions_empty + >>> for out in sequence_partitions_empty([1, 2, 3, 4], 2): + ... print(out) + [[], [1, 2, 3, 4]] + [[1], [2, 3, 4]] + [[1, 2], [3, 4]] + [[1, 2, 3], [4]] + [[1, 2, 3, 4], []] + + See Also + ======== + + sequence_partitions + """ + if n < 1: + return + if n == 1: + yield [l] + return + for i in range(0, len(l) + 1): + for part in sequence_partitions_empty(l[i:], n - 1): + yield [l[:i]] + part + + +def kbins(l, k, ordered=None): + """ + Return sequence ``l`` partitioned into ``k`` bins. + + Examples + ======== + + The default is to give the items in the same order, but grouped + into k partitions without any reordering: + + >>> from sympy.utilities.iterables import kbins + >>> for p in kbins(list(range(5)), 2): + ... print(p) + ... + [[0], [1, 2, 3, 4]] + [[0, 1], [2, 3, 4]] + [[0, 1, 2], [3, 4]] + [[0, 1, 2, 3], [4]] + + The ``ordered`` flag is either None (to give the simple partition + of the elements) or is a 2 digit integer indicating whether the order of + the bins and the order of the items in the bins matters. Given:: + + A = [[0], [1, 2]] + B = [[1, 2], [0]] + C = [[2, 1], [0]] + D = [[0], [2, 1]] + + the following values for ``ordered`` have the shown meanings:: + + 00 means A == B == C == D + 01 means A == B + 10 means A == D + 11 means A == A + + >>> for ordered_flag in [None, 0, 1, 10, 11]: + ... print('ordered = %s' % ordered_flag) + ... for p in kbins(list(range(3)), 2, ordered=ordered_flag): + ... print(' %s' % p) + ... + ordered = None + [[0], [1, 2]] + [[0, 1], [2]] + ordered = 0 + [[0, 1], [2]] + [[0, 2], [1]] + [[0], [1, 2]] + ordered = 1 + [[0], [1, 2]] + [[0], [2, 1]] + [[1], [0, 2]] + [[1], [2, 0]] + [[2], [0, 1]] + [[2], [1, 0]] + ordered = 10 + [[0, 1], [2]] + [[2], [0, 1]] + [[0, 2], [1]] + [[1], [0, 2]] + [[0], [1, 2]] + [[1, 2], [0]] + ordered = 11 + [[0], [1, 2]] + [[0, 1], [2]] + [[0], [2, 1]] + [[0, 2], [1]] + [[1], [0, 2]] + [[1, 0], [2]] + [[1], [2, 0]] + [[1, 2], [0]] + [[2], [0, 1]] + [[2, 0], [1]] + [[2], [1, 0]] + [[2, 1], [0]] + + See Also + ======== + + partitions, multiset_partitions + + """ + if ordered is None: + yield from sequence_partitions(l, k) + elif ordered == 11: + for pl in multiset_permutations(l): + pl = list(pl) + yield from sequence_partitions(pl, k) + elif ordered == 00: + yield from multiset_partitions(l, k) + elif ordered == 10: + for p in multiset_partitions(l, k): + for perm in permutations(p): + yield list(perm) + elif ordered == 1: + for kgot, p in partitions(len(l), k, size=True): + if kgot != k: + continue + for li in multiset_permutations(l): + rv = [] + i = j = 0 + li = list(li) + for size, multiplicity in sorted(p.items()): + for m in range(multiplicity): + j = i + size + rv.append(li[i: j]) + i = j + yield rv + else: + raise ValueError( + 'ordered must be one of 00, 01, 10 or 11, not %s' % ordered) + + +def permute_signs(t): + """Return iterator in which the signs of non-zero elements + of t are permuted. + + Examples + ======== + + >>> from sympy.utilities.iterables import permute_signs + >>> list(permute_signs((0, 1, 2))) + [(0, 1, 2), (0, -1, 2), (0, 1, -2), (0, -1, -2)] + """ + for signs in product(*[(1, -1)]*(len(t) - t.count(0))): + signs = list(signs) + yield type(t)([i*signs.pop() if i else i for i in t]) + + +def signed_permutations(t): + """Return iterator in which the signs of non-zero elements + of t and the order of the elements are permuted and all + returned values are unique. + + Examples + ======== + + >>> from sympy.utilities.iterables import signed_permutations + >>> list(signed_permutations((0, 1, 2))) + [(0, 1, 2), (0, -1, 2), (0, 1, -2), (0, -1, -2), (0, 2, 1), + (0, -2, 1), (0, 2, -1), (0, -2, -1), (1, 0, 2), (-1, 0, 2), + (1, 0, -2), (-1, 0, -2), (1, 2, 0), (-1, 2, 0), (1, -2, 0), + (-1, -2, 0), (2, 0, 1), (-2, 0, 1), (2, 0, -1), (-2, 0, -1), + (2, 1, 0), (-2, 1, 0), (2, -1, 0), (-2, -1, 0)] + """ + return (type(t)(i) for j in multiset_permutations(t) + for i in permute_signs(j)) + + +def rotations(s, dir=1): + """Return a generator giving the items in s as list where + each subsequent list has the items rotated to the left (default) + or right (``dir=-1``) relative to the previous list. + + Examples + ======== + + >>> from sympy import rotations + >>> list(rotations([1,2,3])) + [[1, 2, 3], [2, 3, 1], [3, 1, 2]] + >>> list(rotations([1,2,3], -1)) + [[1, 2, 3], [3, 1, 2], [2, 3, 1]] + """ + seq = list(s) + for i in range(len(seq)): + yield seq + seq = rotate_left(seq, dir) + + +def roundrobin(*iterables): + """roundrobin recipe taken from itertools documentation: + https://docs.python.org/3/library/itertools.html#itertools-recipes + + roundrobin('ABC', 'D', 'EF') --> A D E B F C + + Recipe credited to George Sakkis + """ + nexts = cycle(iter(it).__next__ for it in iterables) + + pending = len(iterables) + while pending: + try: + for nxt in nexts: + yield nxt() + except StopIteration: + pending -= 1 + nexts = cycle(islice(nexts, pending)) + + + +class NotIterable: + """ + Use this as mixin when creating a class which is not supposed to + return true when iterable() is called on its instances because + calling list() on the instance, for example, would result in + an infinite loop. + """ + pass + + +def iterable(i, exclude=(str, dict, NotIterable)): + """ + Return a boolean indicating whether ``i`` is SymPy iterable. + True also indicates that the iterator is finite, e.g. you can + call list(...) on the instance. + + When SymPy is working with iterables, it is almost always assuming + that the iterable is not a string or a mapping, so those are excluded + by default. If you want a pure Python definition, make exclude=None. To + exclude multiple items, pass them as a tuple. + + You can also set the _iterable attribute to True or False on your class, + which will override the checks here, including the exclude test. + + As a rule of thumb, some SymPy functions use this to check if they should + recursively map over an object. If an object is technically iterable in + the Python sense but does not desire this behavior (e.g., because its + iteration is not finite, or because iteration might induce an unwanted + computation), it should disable it by setting the _iterable attribute to False. + + See also: is_sequence + + Examples + ======== + + >>> from sympy.utilities.iterables import iterable + >>> from sympy import Tuple + >>> things = [[1], (1,), set([1]), Tuple(1), (j for j in [1, 2]), {1:2}, '1', 1] + >>> for i in things: + ... print('%s %s' % (iterable(i), type(i))) + True <... 'list'> + True <... 'tuple'> + True <... 'set'> + True + True <... 'generator'> + False <... 'dict'> + False <... 'str'> + False <... 'int'> + + >>> iterable({}, exclude=None) + True + >>> iterable({}, exclude=str) + True + >>> iterable("no", exclude=str) + False + + """ + if hasattr(i, '_iterable'): + return i._iterable + try: + iter(i) + except TypeError: + return False + if exclude: + return not isinstance(i, exclude) + return True + + +def is_sequence(i, include=None): + """ + Return a boolean indicating whether ``i`` is a sequence in the SymPy + sense. If anything that fails the test below should be included as + being a sequence for your application, set 'include' to that object's + type; multiple types should be passed as a tuple of types. + + Note: although generators can generate a sequence, they often need special + handling to make sure their elements are captured before the generator is + exhausted, so these are not included by default in the definition of a + sequence. + + See also: iterable + + Examples + ======== + + >>> from sympy.utilities.iterables import is_sequence + >>> from types import GeneratorType + >>> is_sequence([]) + True + >>> is_sequence(set()) + False + >>> is_sequence('abc') + False + >>> is_sequence('abc', include=str) + True + >>> generator = (c for c in 'abc') + >>> is_sequence(generator) + False + >>> is_sequence(generator, include=(str, GeneratorType)) + True + + """ + return (hasattr(i, '__getitem__') and + iterable(i) or + bool(include) and + isinstance(i, include)) + + +@deprecated( + """ + Using postorder_traversal from the sympy.utilities.iterables submodule is + deprecated. + + Instead, use postorder_traversal from the top-level sympy namespace, like + + sympy.postorder_traversal + """, + deprecated_since_version="1.10", + active_deprecations_target="deprecated-traversal-functions-moved") +def postorder_traversal(node, keys=None): + from sympy.core.traversal import postorder_traversal as _postorder_traversal + return _postorder_traversal(node, keys=keys) + + +@deprecated( + """ + Using interactive_traversal from the sympy.utilities.iterables submodule + is deprecated. + + Instead, use interactive_traversal from the top-level sympy namespace, + like + + sympy.interactive_traversal + """, + deprecated_since_version="1.10", + active_deprecations_target="deprecated-traversal-functions-moved") +def interactive_traversal(expr): + from sympy.interactive.traversal import interactive_traversal as _interactive_traversal + return _interactive_traversal(expr) + + +@deprecated( + """ + Importing default_sort_key from sympy.utilities.iterables is deprecated. + Use from sympy import default_sort_key instead. + """, + deprecated_since_version="1.10", +active_deprecations_target="deprecated-sympy-core-compatibility", +) +def default_sort_key(*args, **kwargs): + from sympy import default_sort_key as _default_sort_key + return _default_sort_key(*args, **kwargs) + + +@deprecated( + """ + Importing default_sort_key from sympy.utilities.iterables is deprecated. + Use from sympy import default_sort_key instead. + """, + deprecated_since_version="1.10", +active_deprecations_target="deprecated-sympy-core-compatibility", +) +def ordered(*args, **kwargs): + from sympy import ordered as _ordered + return _ordered(*args, **kwargs) diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/lambdify.py b/.venv/lib/python3.13/site-packages/sympy/utilities/lambdify.py new file mode 100644 index 0000000000000000000000000000000000000000..6807fdaec10963a60ba88d6e1ec6b2951c19241e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/lambdify.py @@ -0,0 +1,1592 @@ +""" +This module provides convenient functions to transform SymPy expressions to +lambda functions which can be used to calculate numerical values very fast. +""" + +from __future__ import annotations +from typing import Any + +import builtins +import inspect +import keyword +import textwrap +import linecache +import weakref + +# Required despite static analysis claiming it is not used +from sympy.external import import_module # noqa:F401 +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.iterables import (is_sequence, iterable, + NotIterable, flatten) +from sympy.utilities.misc import filldedent + + +__doctest_requires__ = {('lambdify',): ['numpy', 'tensorflow']} + + +# Default namespaces, letting us define translations that can't be defined +# by simple variable maps, like I => 1j +MATH_DEFAULT: dict[str, Any] = {} +CMATH_DEFAULT: dict[str,Any] = {} +MPMATH_DEFAULT: dict[str, Any] = {} +NUMPY_DEFAULT: dict[str, Any] = {"I": 1j} +SCIPY_DEFAULT: dict[str, Any] = {"I": 1j} +CUPY_DEFAULT: dict[str, Any] = {"I": 1j} +JAX_DEFAULT: dict[str, Any] = {"I": 1j} +TENSORFLOW_DEFAULT: dict[str, Any] = {} +TORCH_DEFAULT: dict[str, Any] = {"I": 1j} +SYMPY_DEFAULT: dict[str, Any] = {} +NUMEXPR_DEFAULT: dict[str, Any] = {} + +# These are the namespaces the lambda functions will use. +# These are separate from the names above because they are modified +# throughout this file, whereas the defaults should remain unmodified. + +MATH = MATH_DEFAULT.copy() +CMATH = CMATH_DEFAULT.copy() +MPMATH = MPMATH_DEFAULT.copy() +NUMPY = NUMPY_DEFAULT.copy() +SCIPY = SCIPY_DEFAULT.copy() +CUPY = CUPY_DEFAULT.copy() +JAX = JAX_DEFAULT.copy() +TENSORFLOW = TENSORFLOW_DEFAULT.copy() +TORCH = TORCH_DEFAULT.copy() +SYMPY = SYMPY_DEFAULT.copy() +NUMEXPR = NUMEXPR_DEFAULT.copy() + + +# Mappings between SymPy and other modules function names. +MATH_TRANSLATIONS = { + "ceiling": "ceil", + "E": "e", + "ln": "log", +} + +CMATH_TRANSLATIONS: dict[str, str] = {} + +# NOTE: This dictionary is reused in Function._eval_evalf to allow subclasses +# of Function to automatically evalf. +MPMATH_TRANSLATIONS = { + "Abs": "fabs", + "elliptic_k": "ellipk", + "elliptic_f": "ellipf", + "elliptic_e": "ellipe", + "elliptic_pi": "ellippi", + "ceiling": "ceil", + "chebyshevt": "chebyt", + "chebyshevu": "chebyu", + "assoc_legendre": "legenp", + "E": "e", + "I": "j", + "ln": "log", + #"lowergamma":"lower_gamma", + "oo": "inf", + #"uppergamma":"upper_gamma", + "LambertW": "lambertw", + "MutableDenseMatrix": "matrix", + "ImmutableDenseMatrix": "matrix", + "conjugate": "conj", + "dirichlet_eta": "altzeta", + "Ei": "ei", + "Shi": "shi", + "Chi": "chi", + "Si": "si", + "Ci": "ci", + "RisingFactorial": "rf", + "FallingFactorial": "ff", + "betainc_regularized": "betainc", +} + +NUMPY_TRANSLATIONS: dict[str, str] = { + "Heaviside": "heaviside", +} +SCIPY_TRANSLATIONS: dict[str, str] = { + "jn" : "spherical_jn", + "yn" : "spherical_yn" +} +CUPY_TRANSLATIONS: dict[str, str] = {} +JAX_TRANSLATIONS: dict[str, str] = {} + +TENSORFLOW_TRANSLATIONS: dict[str, str] = {} +TORCH_TRANSLATIONS: dict[str, str] = {} + +NUMEXPR_TRANSLATIONS: dict[str, str] = {} + +# Available modules: +MODULES = { + "math": (MATH, MATH_DEFAULT, MATH_TRANSLATIONS, ("from math import *",)), + "cmath": (CMATH, CMATH_DEFAULT, CMATH_TRANSLATIONS, ("import cmath; from cmath import *",)), + "mpmath": (MPMATH, MPMATH_DEFAULT, MPMATH_TRANSLATIONS, ("from mpmath import *",)), + "numpy": (NUMPY, NUMPY_DEFAULT, NUMPY_TRANSLATIONS, ("import numpy; from numpy import *; from numpy.linalg import *",)), + "scipy": (SCIPY, SCIPY_DEFAULT, SCIPY_TRANSLATIONS, ("import scipy; import numpy; from scipy.special import *",)), + "cupy": (CUPY, CUPY_DEFAULT, CUPY_TRANSLATIONS, ("import cupy",)), + "jax": (JAX, JAX_DEFAULT, JAX_TRANSLATIONS, ("import jax",)), + "tensorflow": (TENSORFLOW, TENSORFLOW_DEFAULT, TENSORFLOW_TRANSLATIONS, ("import tensorflow",)), + "torch": (TORCH, TORCH_DEFAULT, TORCH_TRANSLATIONS, ("import torch",)), + "sympy": (SYMPY, SYMPY_DEFAULT, {}, ( + "from sympy.functions import *", + "from sympy.matrices import *", + "from sympy import Integral, pi, oo, nan, zoo, E, I",)), + "numexpr" : (NUMEXPR, NUMEXPR_DEFAULT, NUMEXPR_TRANSLATIONS, + ("import_module('numexpr')", )), +} + + +def _import(module, reload=False): + """ + Creates a global translation dictionary for module. + + The argument module has to be one of the following strings: "math","cmath" + "mpmath", "numpy", "sympy", "tensorflow", "jax". + These dictionaries map names of Python functions to their equivalent in + other modules. + """ + try: + namespace, namespace_default, translations, import_commands = MODULES[ + module] + except KeyError: + raise NameError( + "'%s' module cannot be used for lambdification" % module) + + # Clear namespace or exit + if namespace != namespace_default: + # The namespace was already generated, don't do it again if not forced. + if reload: + namespace.clear() + namespace.update(namespace_default) + else: + return + + for import_command in import_commands: + if import_command.startswith('import_module'): + module = eval(import_command) + + if module is not None: + namespace.update(module.__dict__) + continue + else: + try: + exec(import_command, {}, namespace) + continue + except ImportError: + pass + + raise ImportError( + "Cannot import '%s' with '%s' command" % (module, import_command)) + + # Add translated names to namespace + for sympyname, translation in translations.items(): + namespace[sympyname] = namespace[translation] + + # For computing the modulus of a SymPy expression we use the builtin abs + # function, instead of the previously used fabs function for all + # translation modules. This is because the fabs function in the math + # module does not accept complex valued arguments. (see issue 9474). The + # only exception, where we don't use the builtin abs function is the + # mpmath translation module, because mpmath.fabs returns mpf objects in + # contrast to abs(). + if 'Abs' not in namespace: + namespace['Abs'] = abs + +# Used for dynamically generated filenames that are inserted into the +# linecache. +_lambdify_generated_counter = 1 + + +@doctest_depends_on(modules=('numpy', 'scipy', 'tensorflow',), python_version=(3,)) +def lambdify(args, expr, modules=None, printer=None, use_imps=True, + dummify=False, cse=False, docstring_limit=1000): + """Convert a SymPy expression into a function that allows for fast + numeric evaluation. + + .. warning:: + This function uses ``exec``, and thus should not be used on + unsanitized input. + + .. deprecated:: 1.7 + Passing a set for the *args* parameter is deprecated as sets are + unordered. Use an ordered iterable such as a list or tuple. + + Explanation + =========== + + For example, to convert the SymPy expression ``sin(x) + cos(x)`` to an + equivalent NumPy function that numerically evaluates it: + + >>> from sympy import sin, cos, symbols, lambdify + >>> import numpy as np + >>> x = symbols('x') + >>> expr = sin(x) + cos(x) + >>> expr + sin(x) + cos(x) + >>> f = lambdify(x, expr, 'numpy') + >>> a = np.array([1, 2]) + >>> f(a) + [1.38177329 0.49315059] + + The primary purpose of this function is to provide a bridge from SymPy + expressions to numerical libraries such as NumPy, SciPy, NumExpr, mpmath, + and tensorflow. In general, SymPy functions do not work with objects from + other libraries, such as NumPy arrays, and functions from numeric + libraries like NumPy or mpmath do not work on SymPy expressions. + ``lambdify`` bridges the two by converting a SymPy expression to an + equivalent numeric function. + + The basic workflow with ``lambdify`` is to first create a SymPy expression + representing whatever mathematical function you wish to evaluate. This + should be done using only SymPy functions and expressions. Then, use + ``lambdify`` to convert this to an equivalent function for numerical + evaluation. For instance, above we created ``expr`` using the SymPy symbol + ``x`` and SymPy functions ``sin`` and ``cos``, then converted it to an + equivalent NumPy function ``f``, and called it on a NumPy array ``a``. + + Parameters + ========== + + args : List[Symbol] + A variable or a list of variables whose nesting represents the + nesting of the arguments that will be passed to the function. + + Variables can be symbols, undefined functions, or matrix symbols. + + >>> from sympy import Eq + >>> from sympy.abc import x, y, z + + The list of variables should match the structure of how the + arguments will be passed to the function. Simply enclose the + parameters as they will be passed in a list. + + To call a function like ``f(x)`` then ``[x]`` + should be the first argument to ``lambdify``; for this + case a single ``x`` can also be used: + + >>> f = lambdify(x, x + 1) + >>> f(1) + 2 + >>> f = lambdify([x], x + 1) + >>> f(1) + 2 + + To call a function like ``f(x, y)`` then ``[x, y]`` will + be the first argument of the ``lambdify``: + + >>> f = lambdify([x, y], x + y) + >>> f(1, 1) + 2 + + To call a function with a single 3-element tuple like + ``f((x, y, z))`` then ``[(x, y, z)]`` will be the first + argument of the ``lambdify``: + + >>> f = lambdify([(x, y, z)], Eq(z**2, x**2 + y**2)) + >>> f((3, 4, 5)) + True + + If two args will be passed and the first is a scalar but + the second is a tuple with two arguments then the items + in the list should match that structure: + + >>> f = lambdify([x, (y, z)], x + y + z) + >>> f(1, (2, 3)) + 6 + + expr : Expr + An expression, list of expressions, or matrix to be evaluated. + + Lists may be nested. + If the expression is a list, the output will also be a list. + + >>> f = lambdify(x, [x, [x + 1, x + 2]]) + >>> f(1) + [1, [2, 3]] + + If it is a matrix, an array will be returned (for the NumPy module). + + >>> from sympy import Matrix + >>> f = lambdify(x, Matrix([x, x + 1])) + >>> f(1) + [[1] + [2]] + + Note that the argument order here (variables then expression) is used + to emulate the Python ``lambda`` keyword. ``lambdify(x, expr)`` works + (roughly) like ``lambda x: expr`` + (see :ref:`lambdify-how-it-works` below). + + modules : str, optional + Specifies the numeric library to use. + + If not specified, *modules* defaults to: + + - ``["scipy", "numpy"]`` if SciPy is installed + - ``["numpy"]`` if only NumPy is installed + - ``["math","cmath", "mpmath", "sympy"]`` if neither is installed. + + That is, SymPy functions are replaced as far as possible by + either ``scipy`` or ``numpy`` functions if available, and Python's + standard library ``math`` and ``cmath``, or ``mpmath`` functions otherwise. + + *modules* can be one of the following types: + + - The strings ``"math"``, ``"cmath"``, ``"mpmath"``, ``"numpy"``, ``"numexpr"``, + ``"scipy"``, ``"sympy"``, or ``"tensorflow"`` or ``"jax"``. This uses the + corresponding printer and namespace mapping for that module. + - A module (e.g., ``math``). This uses the global namespace of the + module. If the module is one of the above known modules, it will + also use the corresponding printer and namespace mapping + (i.e., ``modules=numpy`` is equivalent to ``modules="numpy"``). + - A dictionary that maps names of SymPy functions to arbitrary + functions + (e.g., ``{'sin': custom_sin}``). + - A list that contains a mix of the arguments above, with higher + priority given to entries appearing first + (e.g., to use the NumPy module but override the ``sin`` function + with a custom version, you can use + ``[{'sin': custom_sin}, 'numpy']``). + + dummify : bool, optional + Whether or not the variables in the provided expression that are not + valid Python identifiers are substituted with dummy symbols. + + This allows for undefined functions like ``Function('f')(t)`` to be + supplied as arguments. By default, the variables are only dummified + if they are not valid Python identifiers. + + Set ``dummify=True`` to replace all arguments with dummy symbols + (if ``args`` is not a string) - for example, to ensure that the + arguments do not redefine any built-in names. + + cse : bool, or callable, optional + Large expressions can be computed more efficiently when + common subexpressions are identified and precomputed before + being used multiple time. Finding the subexpressions will make + creation of the 'lambdify' function slower, however. + + When ``True``, ``sympy.simplify.cse`` is used, otherwise (the default) + the user may pass a function matching the ``cse`` signature. + + docstring_limit : int or None + When lambdifying large expressions, a significant proportion of the time + spent inside ``lambdify`` is spent producing a string representation of + the expression for use in the automatically generated docstring of the + returned function. For expressions containing hundreds or more nodes the + resulting docstring often becomes so long and dense that it is difficult + to read. To reduce the runtime of lambdify, the rendering of the full + expression inside the docstring can be disabled. + + When ``None``, the full expression is rendered in the docstring. When + ``0`` or a negative ``int``, an ellipsis is rendering in the docstring + instead of the expression. When a strictly positive ``int``, if the + number of nodes in the expression exceeds ``docstring_limit`` an + ellipsis is rendered in the docstring, otherwise a string representation + of the expression is rendered as normal. The default is ``1000``. + + Examples + ======== + + >>> from sympy.utilities.lambdify import implemented_function + >>> from sympy import sqrt, sin, Matrix + >>> from sympy import Function + >>> from sympy.abc import w, x, y, z + + >>> f = lambdify(x, x**2) + >>> f(2) + 4 + >>> f = lambdify((x, y, z), [z, y, x]) + >>> f(1,2,3) + [3, 2, 1] + >>> f = lambdify(x, sqrt(x)) + >>> f(4) + 2.0 + >>> f = lambdify((x, y), sin(x*y)**2) + >>> f(0, 5) + 0.0 + >>> row = lambdify((x, y), Matrix((x, x + y)).T, modules='sympy') + >>> row(1, 2) + Matrix([[1, 3]]) + + ``lambdify`` can be used to translate SymPy expressions into mpmath + functions. This may be preferable to using ``evalf`` (which uses mpmath on + the backend) in some cases. + + >>> f = lambdify(x, sin(x), 'mpmath') + >>> f(1) + 0.8414709848078965 + + Tuple arguments are handled and the lambdified function should + be called with the same type of arguments as were used to create + the function: + + >>> f = lambdify((x, (y, z)), x + y) + >>> f(1, (2, 4)) + 3 + + The ``flatten`` function can be used to always work with flattened + arguments: + + >>> from sympy.utilities.iterables import flatten + >>> args = w, (x, (y, z)) + >>> vals = 1, (2, (3, 4)) + >>> f = lambdify(flatten(args), w + x + y + z) + >>> f(*flatten(vals)) + 10 + + Functions present in ``expr`` can also carry their own numerical + implementations, in a callable attached to the ``_imp_`` attribute. This + can be used with undefined functions using the ``implemented_function`` + factory: + + >>> f = implemented_function(Function('f'), lambda x: x+1) + >>> func = lambdify(x, f(x)) + >>> func(4) + 5 + + ``lambdify`` always prefers ``_imp_`` implementations to implementations + in other namespaces, unless the ``use_imps`` input parameter is False. + + Usage with Tensorflow: + + >>> import tensorflow as tf + >>> from sympy import Max, sin, lambdify + >>> from sympy.abc import x + + >>> f = Max(x, sin(x)) + >>> func = lambdify(x, f, 'tensorflow') + + After tensorflow v2, eager execution is enabled by default. + If you want to get the compatible result across tensorflow v1 and v2 + as same as this tutorial, run this line. + + >>> tf.compat.v1.enable_eager_execution() + + If you have eager execution enabled, you can get the result out + immediately as you can use numpy. + + If you pass tensorflow objects, you may get an ``EagerTensor`` + object instead of value. + + >>> result = func(tf.constant(1.0)) + >>> print(result) + tf.Tensor(1.0, shape=(), dtype=float32) + >>> print(result.__class__) + + + You can use ``.numpy()`` to get the numpy value of the tensor. + + >>> result.numpy() + 1.0 + + >>> var = tf.Variable(2.0) + >>> result = func(var) # also works for tf.Variable and tf.Placeholder + >>> result.numpy() + 2.0 + + And it works with any shape array. + + >>> tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]]) + >>> result = func(tensor) + >>> result.numpy() + [[1. 2.] + [3. 4.]] + + Notes + ===== + + - For functions involving large array calculations, numexpr can provide a + significant speedup over numpy. Please note that the available functions + for numexpr are more limited than numpy but can be expanded with + ``implemented_function`` and user defined subclasses of Function. If + specified, numexpr may be the only option in modules. The official list + of numexpr functions can be found at: + https://numexpr.readthedocs.io/en/latest/user_guide.html#supported-functions + + - In the above examples, the generated functions can accept scalar + values or numpy arrays as arguments. However, in some cases + the generated function relies on the input being a numpy array: + + >>> import numpy + >>> from sympy import Piecewise + >>> from sympy.testing.pytest import ignore_warnings + >>> f = lambdify(x, Piecewise((x, x <= 1), (1/x, x > 1)), "numpy") + + >>> with ignore_warnings(RuntimeWarning): + ... f(numpy.array([-1, 0, 1, 2])) + [-1. 0. 1. 0.5] + + >>> f(0) + Traceback (most recent call last): + ... + ZeroDivisionError: division by zero + + In such cases, the input should be wrapped in a numpy array: + + >>> with ignore_warnings(RuntimeWarning): + ... float(f(numpy.array([0]))) + 0.0 + + Or if numpy functionality is not required another module can be used: + + >>> f = lambdify(x, Piecewise((x, x <= 1), (1/x, x > 1)), "math") + >>> f(0) + 0 + + .. _lambdify-how-it-works: + + How it works + ============ + + When using this function, it helps a great deal to have an idea of what it + is doing. At its core, lambdify is nothing more than a namespace + translation, on top of a special printer that makes some corner cases work + properly. + + To understand lambdify, first we must properly understand how Python + namespaces work. Say we had two files. One called ``sin_cos_sympy.py``, + with + + .. code:: python + + # sin_cos_sympy.py + + from sympy.functions.elementary.trigonometric import (cos, sin) + + def sin_cos(x): + return sin(x) + cos(x) + + + and one called ``sin_cos_numpy.py`` with + + .. code:: python + + # sin_cos_numpy.py + + from numpy import sin, cos + + def sin_cos(x): + return sin(x) + cos(x) + + The two files define an identical function ``sin_cos``. However, in the + first file, ``sin`` and ``cos`` are defined as the SymPy ``sin`` and + ``cos``. In the second, they are defined as the NumPy versions. + + If we were to import the first file and use the ``sin_cos`` function, we + would get something like + + >>> from sin_cos_sympy import sin_cos # doctest: +SKIP + >>> sin_cos(1) # doctest: +SKIP + cos(1) + sin(1) + + On the other hand, if we imported ``sin_cos`` from the second file, we + would get + + >>> from sin_cos_numpy import sin_cos # doctest: +SKIP + >>> sin_cos(1) # doctest: +SKIP + 1.38177329068 + + In the first case we got a symbolic output, because it used the symbolic + ``sin`` and ``cos`` functions from SymPy. In the second, we got a numeric + result, because ``sin_cos`` used the numeric ``sin`` and ``cos`` functions + from NumPy. But notice that the versions of ``sin`` and ``cos`` that were + used was not inherent to the ``sin_cos`` function definition. Both + ``sin_cos`` definitions are exactly the same. Rather, it was based on the + names defined at the module where the ``sin_cos`` function was defined. + + The key point here is that when function in Python references a name that + is not defined in the function, that name is looked up in the "global" + namespace of the module where that function is defined. + + Now, in Python, we can emulate this behavior without actually writing a + file to disk using the ``exec`` function. ``exec`` takes a string + containing a block of Python code, and a dictionary that should contain + the global variables of the module. It then executes the code "in" that + dictionary, as if it were the module globals. The following is equivalent + to the ``sin_cos`` defined in ``sin_cos_sympy.py``: + + >>> import sympy + >>> module_dictionary = {'sin': sympy.sin, 'cos': sympy.cos} + >>> exec(''' + ... def sin_cos(x): + ... return sin(x) + cos(x) + ... ''', module_dictionary) + >>> sin_cos = module_dictionary['sin_cos'] + >>> sin_cos(1) + cos(1) + sin(1) + + and similarly with ``sin_cos_numpy``: + + >>> import numpy + >>> module_dictionary = {'sin': numpy.sin, 'cos': numpy.cos} + >>> exec(''' + ... def sin_cos(x): + ... return sin(x) + cos(x) + ... ''', module_dictionary) + >>> sin_cos = module_dictionary['sin_cos'] + >>> sin_cos(1) + 1.38177329068 + + So now we can get an idea of how ``lambdify`` works. The name "lambdify" + comes from the fact that we can think of something like ``lambdify(x, + sin(x) + cos(x), 'numpy')`` as ``lambda x: sin(x) + cos(x)``, where + ``sin`` and ``cos`` come from the ``numpy`` namespace. This is also why + the symbols argument is first in ``lambdify``, as opposed to most SymPy + functions where it comes after the expression: to better mimic the + ``lambda`` keyword. + + ``lambdify`` takes the input expression (like ``sin(x) + cos(x)``) and + + 1. Converts it to a string + 2. Creates a module globals dictionary based on the modules that are + passed in (by default, it uses the NumPy module) + 3. Creates the string ``"def func({vars}): return {expr}"``, where ``{vars}`` is the + list of variables separated by commas, and ``{expr}`` is the string + created in step 1., then ``exec``s that string with the module globals + namespace and returns ``func``. + + In fact, functions returned by ``lambdify`` support inspection. So you can + see exactly how they are defined by using ``inspect.getsource``, or ``??`` if you + are using IPython or the Jupyter notebook. + + >>> f = lambdify(x, sin(x) + cos(x)) + >>> import inspect + >>> print(inspect.getsource(f)) + def _lambdifygenerated(x): + return sin(x) + cos(x) + + This shows us the source code of the function, but not the namespace it + was defined in. We can inspect that by looking at the ``__globals__`` + attribute of ``f``: + + >>> f.__globals__['sin'] + + >>> f.__globals__['cos'] + + >>> f.__globals__['sin'] is numpy.sin + True + + This shows us that ``sin`` and ``cos`` in the namespace of ``f`` will be + ``numpy.sin`` and ``numpy.cos``. + + Note that there are some convenience layers in each of these steps, but at + the core, this is how ``lambdify`` works. Step 1 is done using the + ``LambdaPrinter`` printers defined in the printing module (see + :mod:`sympy.printing.lambdarepr`). This allows different SymPy expressions + to define how they should be converted to a string for different modules. + You can change which printer ``lambdify`` uses by passing a custom printer + in to the ``printer`` argument. + + Step 2 is augmented by certain translations. There are default + translations for each module, but you can provide your own by passing a + list to the ``modules`` argument. For instance, + + >>> def mysin(x): + ... print('taking the sin of', x) + ... return numpy.sin(x) + ... + >>> f = lambdify(x, sin(x), [{'sin': mysin}, 'numpy']) + >>> f(1) + taking the sin of 1 + 0.8414709848078965 + + The globals dictionary is generated from the list by merging the + dictionary ``{'sin': mysin}`` and the module dictionary for NumPy. The + merging is done so that earlier items take precedence, which is why + ``mysin`` is used above instead of ``numpy.sin``. + + If you want to modify the way ``lambdify`` works for a given function, it + is usually easiest to do so by modifying the globals dictionary as such. + In more complicated cases, it may be necessary to create and pass in a + custom printer. + + Finally, step 3 is augmented with certain convenience operations, such as + the addition of a docstring. + + Understanding how ``lambdify`` works can make it easier to avoid certain + gotchas when using it. For instance, a common mistake is to create a + lambdified function for one module (say, NumPy), and pass it objects from + another (say, a SymPy expression). + + For instance, say we create + + >>> from sympy.abc import x + >>> f = lambdify(x, x + 1, 'numpy') + + Now if we pass in a NumPy array, we get that array plus 1 + + >>> import numpy + >>> a = numpy.array([1, 2]) + >>> f(a) + [2 3] + + But what happens if you make the mistake of passing in a SymPy expression + instead of a NumPy array: + + >>> f(x + 1) + x + 2 + + This worked, but it was only by accident. Now take a different lambdified + function: + + >>> from sympy import sin + >>> g = lambdify(x, x + sin(x), 'numpy') + + This works as expected on NumPy arrays: + + >>> g(a) + [1.84147098 2.90929743] + + But if we try to pass in a SymPy expression, it fails + + >>> g(x + 1) + Traceback (most recent call last): + ... + TypeError: loop of ufunc does not support argument 0 of type Add which has + no callable sin method + + Now, let's look at what happened. The reason this fails is that ``g`` + calls ``numpy.sin`` on the input expression, and ``numpy.sin`` does not + know how to operate on a SymPy object. **As a general rule, NumPy + functions do not know how to operate on SymPy expressions, and SymPy + functions do not know how to operate on NumPy arrays. This is why lambdify + exists: to provide a bridge between SymPy and NumPy.** + + However, why is it that ``f`` did work? That's because ``f`` does not call + any functions, it only adds 1. So the resulting function that is created, + ``def _lambdifygenerated(x): return x + 1`` does not depend on the globals + namespace it is defined in. Thus it works, but only by accident. A future + version of ``lambdify`` may remove this behavior. + + Be aware that certain implementation details described here may change in + future versions of SymPy. The API of passing in custom modules and + printers will not change, but the details of how a lambda function is + created may change. However, the basic idea will remain the same, and + understanding it will be helpful to understanding the behavior of + lambdify. + + **In general: you should create lambdified functions for one module (say, + NumPy), and only pass it input types that are compatible with that module + (say, NumPy arrays).** Remember that by default, if the ``module`` + argument is not provided, ``lambdify`` creates functions using the NumPy + and SciPy namespaces. + """ + from sympy.core.symbol import Symbol + from sympy.core.expr import Expr + + # If the user hasn't specified any modules, use what is available. + if modules is None: + try: + _import("scipy") + except ImportError: + try: + _import("numpy") + except ImportError: + # Use either numpy (if available) or python.math where possible. + # XXX: This leads to different behaviour on different systems and + # might be the reason for irreproducible errors. + modules = ["math", "mpmath", "sympy"] + else: + modules = ["numpy"] + else: + modules = ["numpy", "scipy"] + + # Get the needed namespaces. + namespaces = [] + # First find any function implementations + if use_imps: + namespaces.append(_imp_namespace(expr)) + # Check for dict before iterating + if isinstance(modules, (dict, str)) or not hasattr(modules, '__iter__'): + namespaces.append(modules) + else: + # consistency check + if _module_present('numexpr', modules) and len(modules) > 1: + raise TypeError("numexpr must be the only item in 'modules'") + namespaces += list(modules) + # fill namespace with first having highest priority + namespace = {} + for m in namespaces[::-1]: + buf = _get_namespace(m) + namespace.update(buf) + + if hasattr(expr, "atoms"): + #Try if you can extract symbols from the expression. + #Move on if expr.atoms in not implemented. + syms = expr.atoms(Symbol) + for term in syms: + namespace.update({str(term): term}) + + if printer is None: + if _module_present('mpmath', namespaces): + from sympy.printing.pycode import MpmathPrinter as Printer # type: ignore + elif _module_present('scipy', namespaces): + from sympy.printing.numpy import SciPyPrinter as Printer # type: ignore + elif _module_present('numpy', namespaces): + from sympy.printing.numpy import NumPyPrinter as Printer # type: ignore + elif _module_present('cupy', namespaces): + from sympy.printing.numpy import CuPyPrinter as Printer # type: ignore + elif _module_present('jax', namespaces): + from sympy.printing.numpy import JaxPrinter as Printer # type: ignore + elif _module_present('numexpr', namespaces): + from sympy.printing.lambdarepr import NumExprPrinter as Printer # type: ignore + elif _module_present('tensorflow', namespaces): + from sympy.printing.tensorflow import TensorflowPrinter as Printer # type: ignore + elif _module_present('torch', namespaces): + from sympy.printing.pytorch import TorchPrinter as Printer # type: ignore + elif _module_present('sympy', namespaces): + from sympy.printing.pycode import SymPyPrinter as Printer # type: ignore + elif _module_present('cmath', namespaces): + from sympy.printing.pycode import CmathPrinter as Printer # type: ignore + else: + from sympy.printing.pycode import PythonCodePrinter as Printer # type: ignore + user_functions = {} + for m in namespaces[::-1]: + if isinstance(m, dict): + for k in m: + user_functions[k] = k + printer = Printer({'fully_qualified_modules': False, 'inline': True, + 'allow_unknown_functions': True, + 'user_functions': user_functions}) + + if isinstance(args, set): + sympy_deprecation_warning( + """ +Passing the function arguments to lambdify() as a set is deprecated. This +leads to unpredictable results since sets are unordered. Instead, use a list +or tuple for the function arguments. + """, + deprecated_since_version="1.6.3", + active_deprecations_target="deprecated-lambdify-arguments-set", + ) + + # Get the names of the args, for creating a docstring + iterable_args = (args,) if isinstance(args, Expr) else args + names = [] + + # Grab the callers frame, for getting the names by inspection (if needed) + callers_local_vars = inspect.currentframe().f_back.f_locals.items() # type: ignore + for n, var in enumerate(iterable_args): + if hasattr(var, 'name'): + names.append(var.name) + else: + # It's an iterable. Try to get name by inspection of calling frame. + name_list = [var_name for var_name, var_val in callers_local_vars + if var_val is var] + if len(name_list) == 1: + names.append(name_list[0]) + else: + # Cannot infer name with certainty. arg_# will have to do. + names.append('arg_' + str(n)) + + # Create the function definition code and execute it + funcname = '_lambdifygenerated' + if _module_present('tensorflow', namespaces): + funcprinter = _TensorflowEvaluatorPrinter(printer, dummify) + else: + funcprinter = _EvaluatorPrinter(printer, dummify) + + if cse == True: + from sympy.simplify.cse_main import cse as _cse + cses, _expr = _cse(expr, list=False) + elif callable(cse): + cses, _expr = cse(expr) + else: + cses, _expr = (), expr + funcstr = funcprinter.doprint(funcname, iterable_args, _expr, cses=cses) + + # Collect the module imports from the code printers. + imp_mod_lines = [] + for mod, keys in (getattr(printer, 'module_imports', None) or {}).items(): + for k in keys: + if k not in namespace: + ln = "from %s import %s" % (mod, k) + try: + exec(ln, {}, namespace) + except ImportError: + # Tensorflow 2.0 has issues with importing a specific + # function from its submodule. + # https://github.com/tensorflow/tensorflow/issues/33022 + ln = "%s = %s.%s" % (k, mod, k) + exec(ln, {}, namespace) + imp_mod_lines.append(ln) + + # Provide lambda expression with builtins, and compatible implementation of range + namespace.update({'builtins':builtins, 'range':range}) + + funclocals = {} + global _lambdify_generated_counter + filename = '' % _lambdify_generated_counter + _lambdify_generated_counter += 1 + c = compile(funcstr, filename, 'exec') + exec(c, namespace, funclocals) + # mtime has to be None or else linecache.checkcache will remove it + linecache.cache[filename] = (len(funcstr), None, funcstr.splitlines(True), filename) # type: ignore + + # Remove the entry from the linecache when the object is garbage collected + def cleanup_linecache(filename): + def _cleanup(): + if filename in linecache.cache: + del linecache.cache[filename] + return _cleanup + + func = funclocals[funcname] + + weakref.finalize(func, cleanup_linecache(filename)) + + # Apply the docstring + sig = "func({})".format(", ".join(str(i) for i in names)) + sig = textwrap.fill(sig, subsequent_indent=' '*8) + if _too_large_for_docstring(expr, docstring_limit): + expr_str = "EXPRESSION REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)" + src_str = "SOURCE CODE REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)" + else: + expr_str = str(expr) + if len(expr_str) > 78: + expr_str = textwrap.wrap(expr_str, 75)[0] + '...' + src_str = funcstr + func.__doc__ = ( + "Created with lambdify. Signature:\n\n" + "{sig}\n\n" + "Expression:\n\n" + "{expr}\n\n" + "Source code:\n\n" + "{src}\n\n" + "Imported modules:\n\n" + "{imp_mods}" + ).format(sig=sig, expr=expr_str, src=src_str, imp_mods='\n'.join(imp_mod_lines)) + return func + +def _module_present(modname, modlist): + if modname in modlist: + return True + for m in modlist: + if hasattr(m, '__name__') and m.__name__ == modname: + return True + return False + +def _get_namespace(m): + """ + This is used by _lambdify to parse its arguments. + """ + if isinstance(m, str): + _import(m) + return MODULES[m][0] + elif isinstance(m, dict): + return m + elif hasattr(m, "__dict__"): + return m.__dict__ + else: + raise TypeError("Argument must be either a string, dict or module but it is: %s" % m) + + +def _recursive_to_string(doprint, arg): + """Functions in lambdify accept both SymPy types and non-SymPy types such as python + lists and tuples. This method ensures that we only call the doprint method of the + printer with SymPy types (so that the printer safely can use SymPy-methods).""" + from sympy.matrices.matrixbase import MatrixBase + from sympy.core.basic import Basic + + if isinstance(arg, (Basic, MatrixBase)): + return doprint(arg) + elif iterable(arg): + if isinstance(arg, list): + left, right = "[", "]" + elif isinstance(arg, tuple): + left, right = "(", ",)" + if not arg: + return "()" + else: + raise NotImplementedError("unhandled type: %s, %s" % (type(arg), arg)) + return left +', '.join(_recursive_to_string(doprint, e) for e in arg) + right + elif isinstance(arg, str): + return arg + else: + return doprint(arg) + + +def lambdastr(args, expr, printer=None, dummify=None): + """ + Returns a string that can be evaluated to a lambda function. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.utilities.lambdify import lambdastr + >>> lambdastr(x, x**2) + 'lambda x: (x**2)' + >>> lambdastr((x,y,z), [z,y,x]) + 'lambda x,y,z: ([z, y, x])' + + Although tuples may not appear as arguments to lambda in Python 3, + lambdastr will create a lambda function that will unpack the original + arguments so that nested arguments can be handled: + + >>> lambdastr((x, (y, z)), x + y) + 'lambda _0,_1: (lambda x,y,z: (x + y))(_0,_1[0],_1[1])' + """ + # Transforming everything to strings. + from sympy.matrices import DeferredVector + from sympy.core.basic import Basic + from sympy.core.function import (Derivative, Function) + from sympy.core.symbol import (Dummy, Symbol) + from sympy.core.sympify import sympify + + if printer is not None: + if inspect.isfunction(printer): + lambdarepr = printer + else: + if inspect.isclass(printer): + lambdarepr = lambda expr: printer().doprint(expr) + else: + lambdarepr = lambda expr: printer.doprint(expr) + else: + #XXX: This has to be done here because of circular imports + from sympy.printing.lambdarepr import lambdarepr + + def sub_args(args, dummies_dict): + if isinstance(args, str): + return args + elif isinstance(args, DeferredVector): + return str(args) + elif iterable(args): + dummies = flatten([sub_args(a, dummies_dict) for a in args]) + return ",".join(str(a) for a in dummies) + else: + # replace these with Dummy symbols + if isinstance(args, (Function, Symbol, Derivative)): + dummies = Dummy() + dummies_dict.update({args : dummies}) + return str(dummies) + else: + return str(args) + + def sub_expr(expr, dummies_dict): + expr = sympify(expr) + # dict/tuple are sympified to Basic + if isinstance(expr, Basic): + expr = expr.xreplace(dummies_dict) + # list is not sympified to Basic + elif isinstance(expr, list): + expr = [sub_expr(a, dummies_dict) for a in expr] + return expr + + # Transform args + def isiter(l): + return iterable(l, exclude=(str, DeferredVector, NotIterable)) + + def flat_indexes(iterable): + n = 0 + + for el in iterable: + if isiter(el): + for ndeep in flat_indexes(el): + yield (n,) + ndeep + else: + yield (n,) + + n += 1 + + if dummify is None: + dummify = any(isinstance(a, Basic) and + a.atoms(Function, Derivative) for a in ( + args if isiter(args) else [args])) + + if isiter(args) and any(isiter(i) for i in args): + dum_args = [str(Dummy(str(i))) for i in range(len(args))] + + indexed_args = ','.join([ + dum_args[ind[0]] + ''.join(["[%s]" % k for k in ind[1:]]) + for ind in flat_indexes(args)]) + + lstr = lambdastr(flatten(args), expr, printer=printer, dummify=dummify) + + return 'lambda %s: (%s)(%s)' % (','.join(dum_args), lstr, indexed_args) + + dummies_dict = {} + if dummify: + args = sub_args(args, dummies_dict) + else: + if isinstance(args, str): + pass + elif iterable(args, exclude=DeferredVector): + args = ",".join(str(a) for a in args) + + # Transform expr + if dummify: + if isinstance(expr, str): + pass + else: + expr = sub_expr(expr, dummies_dict) + expr = _recursive_to_string(lambdarepr, expr) + return "lambda %s: (%s)" % (args, expr) + +class _EvaluatorPrinter: + def __init__(self, printer=None, dummify=False): + self._dummify = dummify + + #XXX: This has to be done here because of circular imports + from sympy.printing.lambdarepr import LambdaPrinter + + if printer is None: + printer = LambdaPrinter() + + if inspect.isfunction(printer): + self._exprrepr = printer + else: + if inspect.isclass(printer): + printer = printer() + + self._exprrepr = printer.doprint + + #if hasattr(printer, '_print_Symbol'): + # symbolrepr = printer._print_Symbol + + #if hasattr(printer, '_print_Dummy'): + # dummyrepr = printer._print_Dummy + + # Used to print the generated function arguments in a standard way + self._argrepr = LambdaPrinter().doprint + + def doprint(self, funcname, args, expr, *, cses=()): + """ + Returns the function definition code as a string. + """ + from sympy.core.symbol import Dummy + + funcbody = [] + + if not iterable(args): + args = [args] + + if cses: + cses = list(cses) + subvars, subexprs = zip(*cses) + exprs = [expr] + list(subexprs) + argstrs, exprs = self._preprocess(args, exprs, cses=cses) + expr, subexprs = exprs[0], exprs[1:] + cses = zip(subvars, subexprs) + else: + argstrs, expr = self._preprocess(args, expr) + + # Generate argument unpacking and final argument list + funcargs = [] + unpackings = [] + + for argstr in argstrs: + if iterable(argstr): + funcargs.append(self._argrepr(Dummy())) + unpackings.extend(self._print_unpacking(argstr, funcargs[-1])) + else: + funcargs.append(argstr) + + funcsig = 'def {}({}):'.format(funcname, ', '.join(funcargs)) + + # Wrap input arguments before unpacking + funcbody.extend(self._print_funcargwrapping(funcargs)) + + funcbody.extend(unpackings) + + for s, e in cses: + if e is None: + funcbody.append('del {}'.format(self._exprrepr(s))) + else: + funcbody.append('{} = {}'.format(self._exprrepr(s), self._exprrepr(e))) + + # Subs may appear in expressions generated by .diff() + subs_assignments = [] + expr = self._handle_Subs(expr, out=subs_assignments) + for lhs, rhs in subs_assignments: + funcbody.append('{} = {}'.format(self._exprrepr(lhs), self._exprrepr(rhs))) + + str_expr = _recursive_to_string(self._exprrepr, expr) + + if '\n' in str_expr: + str_expr = '({})'.format(str_expr) + funcbody.append('return {}'.format(str_expr)) + + funclines = [funcsig] + funclines.extend([' ' + line for line in funcbody]) + + return '\n'.join(funclines) + '\n' + + @classmethod + def _is_safe_ident(cls, ident): + return isinstance(ident, str) and ident.isidentifier() \ + and not keyword.iskeyword(ident) + + def _preprocess(self, args, expr, cses=(), _dummies_dict=None): + """Preprocess args, expr to replace arguments that do not map + to valid Python identifiers. + + Returns string form of args, and updated expr. + """ + from sympy.core.basic import Basic + from sympy.core.sorting import ordered + from sympy.core.function import (Derivative, Function) + from sympy.core.symbol import Dummy, uniquely_named_symbol + from sympy.matrices import DeferredVector + from sympy.core.expr import Expr + + # Args of type Dummy can cause name collisions with args + # of type Symbol. Force dummify of everything in this + # situation. + dummify = self._dummify or any( + isinstance(arg, Dummy) for arg in flatten(args)) + + argstrs = [None]*len(args) + if _dummies_dict is None: + _dummies_dict = {} + + def update_dummies(arg, dummy): + _dummies_dict[arg] = dummy + for repl, sub in cses: + arg = arg.xreplace({sub: repl}) + _dummies_dict[arg] = dummy + + for arg, i in reversed(list(ordered(zip(args, range(len(args)))))): + if iterable(arg): + s, expr = self._preprocess(arg, expr, cses=cses, _dummies_dict=_dummies_dict) + elif isinstance(arg, DeferredVector): + s = str(arg) + elif isinstance(arg, Basic) and arg.is_symbol: + s = str(arg) + if dummify or not self._is_safe_ident(s): + dummy = Dummy() + if isinstance(expr, Expr): + dummy = uniquely_named_symbol( + dummy.name, expr, modify=lambda s: '_' + s) + s = self._argrepr(dummy) + update_dummies(arg, dummy) + expr = self._subexpr(expr, _dummies_dict) + elif dummify or isinstance(arg, (Function, Derivative)): + dummy = Dummy() + s = self._argrepr(dummy) + update_dummies(arg, dummy) + expr = self._subexpr(expr, _dummies_dict) + else: + s = str(arg) + argstrs[i] = s + return argstrs, expr + + def _subexpr(self, expr, dummies_dict): + from sympy.matrices import DeferredVector + from sympy.core.sympify import sympify + + expr = sympify(expr) + xreplace = getattr(expr, 'xreplace', None) + if xreplace is not None: + expr = xreplace(dummies_dict) + else: + if isinstance(expr, DeferredVector): + pass + elif isinstance(expr, dict): + k = [self._subexpr(sympify(a), dummies_dict) for a in expr.keys()] + v = [self._subexpr(sympify(a), dummies_dict) for a in expr.values()] + expr = dict(zip(k, v)) + elif isinstance(expr, tuple): + expr = tuple(self._subexpr(sympify(a), dummies_dict) for a in expr) + elif isinstance(expr, list): + expr = [self._subexpr(sympify(a), dummies_dict) for a in expr] + return expr + + def _print_funcargwrapping(self, args): + """Generate argument wrapping code. + + args is the argument list of the generated function (strings). + + Return value is a list of lines of code that will be inserted at + the beginning of the function definition. + """ + return [] + + def _print_unpacking(self, unpackto, arg): + """Generate argument unpacking code. + + arg is the function argument to be unpacked (a string), and + unpackto is a list or nested lists of the variable names (strings) to + unpack to. + """ + def unpack_lhs(lvalues): + return '[{}]'.format(', '.join( + unpack_lhs(val) if iterable(val) else val for val in lvalues)) + + return ['{} = {}'.format(unpack_lhs(unpackto), arg)] + + def _handle_Subs(self, expr, out): + """Any instance of Subs is extracted and returned as assignment pairs.""" + from sympy.core.basic import Basic + from sympy.core.function import Subs + from sympy.core.symbol import Dummy + from sympy.matrices.matrixbase import MatrixBase + + def _replace(ex, variables, point): + safe = {} + for lhs, rhs in zip(variables, point): + dummy = Dummy() + safe[lhs] = dummy + out.append((dummy, rhs)) + return ex.xreplace(safe) + + if isinstance(expr, (Basic, MatrixBase)): + expr = expr.replace(Subs, _replace) + elif iterable(expr): + expr = type(expr)([self._handle_Subs(e, out) for e in expr]) + return expr + +class _TensorflowEvaluatorPrinter(_EvaluatorPrinter): + def _print_unpacking(self, lvalues, rvalue): + """Generate argument unpacking code. + + This method is used when the input value is not iterable, + but can be indexed (see issue #14655). + """ + + def flat_indexes(elems): + n = 0 + + for el in elems: + if iterable(el): + for ndeep in flat_indexes(el): + yield (n,) + ndeep + else: + yield (n,) + + n += 1 + + indexed = ', '.join('{}[{}]'.format(rvalue, ']['.join(map(str, ind))) + for ind in flat_indexes(lvalues)) + + return ['[{}] = [{}]'.format(', '.join(flatten(lvalues)), indexed)] + +def _imp_namespace(expr, namespace=None): + """ Return namespace dict with function implementations + + We need to search for functions in anything that can be thrown at + us - that is - anything that could be passed as ``expr``. Examples + include SymPy expressions, as well as tuples, lists and dicts that may + contain SymPy expressions. + + Parameters + ---------- + expr : object + Something passed to lambdify, that will generate valid code from + ``str(expr)``. + namespace : None or mapping + Namespace to fill. None results in new empty dict + + Returns + ------- + namespace : dict + dict with keys of implemented function names within ``expr`` and + corresponding values being the numerical implementation of + function + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy.utilities.lambdify import implemented_function, _imp_namespace + >>> from sympy import Function + >>> f = implemented_function(Function('f'), lambda x: x+1) + >>> g = implemented_function(Function('g'), lambda x: x*10) + >>> namespace = _imp_namespace(f(g(x))) + >>> sorted(namespace.keys()) + ['f', 'g'] + """ + # Delayed import to avoid circular imports + from sympy.core.function import FunctionClass + if namespace is None: + namespace = {} + # tuples, lists, dicts are valid expressions + if is_sequence(expr): + for arg in expr: + _imp_namespace(arg, namespace) + return namespace + elif isinstance(expr, dict): + for key, val in expr.items(): + # functions can be in dictionary keys + _imp_namespace(key, namespace) + _imp_namespace(val, namespace) + return namespace + # SymPy expressions may be Functions themselves + func = getattr(expr, 'func', None) + if isinstance(func, FunctionClass): + imp = getattr(func, '_imp_', None) + if imp is not None: + name = expr.func.__name__ + if name in namespace and namespace[name] != imp: + raise ValueError('We found more than one ' + 'implementation with name ' + '"%s"' % name) + namespace[name] = imp + # and / or they may take Functions as arguments + if hasattr(expr, 'args'): + for arg in expr.args: + _imp_namespace(arg, namespace) + return namespace + + +def implemented_function(symfunc, implementation): + """ Add numerical ``implementation`` to function ``symfunc``. + + ``symfunc`` can be an ``UndefinedFunction`` instance, or a name string. + In the latter case we create an ``UndefinedFunction`` instance with that + name. + + Be aware that this is a quick workaround, not a general method to create + special symbolic functions. If you want to create a symbolic function to be + used by all the machinery of SymPy you should subclass the ``Function`` + class. + + Parameters + ---------- + symfunc : ``str`` or ``UndefinedFunction`` instance + If ``str``, then create new ``UndefinedFunction`` with this as + name. If ``symfunc`` is an Undefined function, create a new function + with the same name and the implemented function attached. + implementation : callable + numerical implementation to be called by ``evalf()`` or ``lambdify`` + + Returns + ------- + afunc : sympy.FunctionClass instance + function with attached implementation + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy.utilities.lambdify import implemented_function + >>> from sympy import lambdify + >>> f = implemented_function('f', lambda x: x+1) + >>> lam_f = lambdify(x, f(x)) + >>> lam_f(4) + 5 + """ + # Delayed import to avoid circular imports + from sympy.core.function import UndefinedFunction + # if name, create function to hold implementation + kwargs = {} + if isinstance(symfunc, UndefinedFunction): + kwargs = symfunc._kwargs + symfunc = symfunc.__name__ + if isinstance(symfunc, str): + # Keyword arguments to UndefinedFunction are added as attributes to + # the created class. + symfunc = UndefinedFunction( + symfunc, _imp_=staticmethod(implementation), **kwargs) + elif not isinstance(symfunc, UndefinedFunction): + raise ValueError(filldedent(''' + symfunc should be either a string or + an UndefinedFunction instance.''')) + return symfunc + + +def _too_large_for_docstring(expr, limit): + """Decide whether an ``Expr`` is too large to be fully rendered in a + ``lambdify`` docstring. + + This is a fast alternative to ``count_ops``, which can become prohibitively + slow for large expressions, because in this instance we only care whether + ``limit`` is exceeded rather than counting the exact number of nodes in the + expression. + + Parameters + ========== + expr : ``Expr``, (nested) ``list`` of ``Expr``, or ``Matrix`` + The same objects that can be passed to the ``expr`` argument of + ``lambdify``. + limit : ``int`` or ``None`` + The threshold above which an expression contains too many nodes to be + usefully rendered in the docstring. If ``None`` then there is no limit. + + Returns + ======= + bool + ``True`` if the number of nodes in the expression exceeds the limit, + ``False`` otherwise. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.utilities.lambdify import _too_large_for_docstring + >>> expr = x + >>> _too_large_for_docstring(expr, None) + False + >>> _too_large_for_docstring(expr, 100) + False + >>> _too_large_for_docstring(expr, 1) + False + >>> _too_large_for_docstring(expr, 0) + True + >>> _too_large_for_docstring(expr, -1) + True + + Does this split it? + + >>> expr = [x, y, z] + >>> _too_large_for_docstring(expr, None) + False + >>> _too_large_for_docstring(expr, 100) + False + >>> _too_large_for_docstring(expr, 1) + True + >>> _too_large_for_docstring(expr, 0) + True + >>> _too_large_for_docstring(expr, -1) + True + + >>> expr = [x, [y], z, [[x+y], [x*y*z, [x+y+z]]]] + >>> _too_large_for_docstring(expr, None) + False + >>> _too_large_for_docstring(expr, 100) + False + >>> _too_large_for_docstring(expr, 1) + True + >>> _too_large_for_docstring(expr, 0) + True + >>> _too_large_for_docstring(expr, -1) + True + + >>> expr = ((x + y + z)**5).expand() + >>> _too_large_for_docstring(expr, None) + False + >>> _too_large_for_docstring(expr, 100) + True + >>> _too_large_for_docstring(expr, 1) + True + >>> _too_large_for_docstring(expr, 0) + True + >>> _too_large_for_docstring(expr, -1) + True + + >>> from sympy import Matrix + >>> expr = Matrix([[(x + y + z), ((x + y + z)**2).expand(), + ... ((x + y + z)**3).expand(), ((x + y + z)**4).expand()]]) + >>> _too_large_for_docstring(expr, None) + False + >>> _too_large_for_docstring(expr, 1000) + False + >>> _too_large_for_docstring(expr, 100) + True + >>> _too_large_for_docstring(expr, 1) + True + >>> _too_large_for_docstring(expr, 0) + True + >>> _too_large_for_docstring(expr, -1) + True + + """ + # Must be imported here to avoid a circular import error + from sympy.core.traversal import postorder_traversal + + if limit is None: + return False + + i = 0 + for _ in postorder_traversal(expr): + i += 1 + if i > limit: + return True + return False diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/magic.py b/.venv/lib/python3.13/site-packages/sympy/utilities/magic.py new file mode 100644 index 0000000000000000000000000000000000000000..e853a0ad9a85bc252dcb24e8a1ecbfca422ac3fd --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/magic.py @@ -0,0 +1,12 @@ +"""Functions that involve magic. """ + +def pollute(names, objects): + """Pollute the global namespace with symbols -> objects mapping. """ + from inspect import currentframe + frame = currentframe().f_back.f_back + + try: + for name, obj in zip(names, objects): + frame.f_globals[name] = obj + finally: + del frame # break cyclic dependencies as stated in inspect docs diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/matchpy_connector.py b/.venv/lib/python3.13/site-packages/sympy/utilities/matchpy_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..35aa013294b93bbcfe4a1bf4ec96b629ea5a468f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/matchpy_connector.py @@ -0,0 +1,340 @@ +""" +The objects in this module allow the usage of the MatchPy pattern matching +library on SymPy expressions. +""" +import re +from typing import List, Callable, NamedTuple, Any, Dict + +from sympy.core.sympify import _sympify +from sympy.external import import_module +from sympy.functions import (log, sin, cos, tan, cot, csc, sec, erf, gamma, uppergamma) +from sympy.functions.elementary.hyperbolic import acosh, asinh, atanh, acoth, acsch, asech, cosh, sinh, tanh, coth, sech, csch +from sympy.functions.elementary.trigonometric import atan, acsc, asin, acot, acos, asec +from sympy.functions.special.error_functions import fresnelc, fresnels, erfc, erfi, Ei +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.relational import (Equality, Unequality) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import exp +from sympy.integrals.integrals import Integral +from sympy.printing.repr import srepr +from sympy.utilities.decorator import doctest_depends_on + + +matchpy = import_module("matchpy") + + +__doctest_requires__ = {('*',): ['matchpy']} + + +if matchpy: + from matchpy import Operation, CommutativeOperation, AssociativeOperation, OneIdentityOperation + from matchpy.expressions.functions import op_iter, create_operation_expression, op_len + + Operation.register(Integral) + Operation.register(Pow) + OneIdentityOperation.register(Pow) + + Operation.register(Add) + OneIdentityOperation.register(Add) + CommutativeOperation.register(Add) + AssociativeOperation.register(Add) + + Operation.register(Mul) + OneIdentityOperation.register(Mul) + CommutativeOperation.register(Mul) + AssociativeOperation.register(Mul) + + Operation.register(Equality) + CommutativeOperation.register(Equality) + Operation.register(Unequality) + CommutativeOperation.register(Unequality) + + Operation.register(exp) + Operation.register(log) + Operation.register(gamma) + Operation.register(uppergamma) + Operation.register(fresnels) + Operation.register(fresnelc) + Operation.register(erf) + Operation.register(Ei) + Operation.register(erfc) + Operation.register(erfi) + Operation.register(sin) + Operation.register(cos) + Operation.register(tan) + Operation.register(cot) + Operation.register(csc) + Operation.register(sec) + Operation.register(sinh) + Operation.register(cosh) + Operation.register(tanh) + Operation.register(coth) + Operation.register(csch) + Operation.register(sech) + Operation.register(asin) + Operation.register(acos) + Operation.register(atan) + Operation.register(acot) + Operation.register(acsc) + Operation.register(asec) + Operation.register(asinh) + Operation.register(acosh) + Operation.register(atanh) + Operation.register(acoth) + Operation.register(acsch) + Operation.register(asech) + + @op_iter.register(Integral) # type: ignore + def _(operation): + return iter((operation._args[0],) + operation._args[1]) + + @op_iter.register(Basic) # type: ignore + def _(operation): + return iter(operation._args) + + @op_len.register(Integral) # type: ignore + def _(operation): + return 1 + len(operation._args[1]) + + @op_len.register(Basic) # type: ignore + def _(operation): + return len(operation._args) + + @create_operation_expression.register(Basic) + def sympy_op_factory(old_operation, new_operands, variable_name=True): + return type(old_operation)(*new_operands) + + +if matchpy: + from matchpy import Wildcard +else: + class Wildcard: # type: ignore + def __init__(self, min_length, fixed_size, variable_name, optional): + self.min_count = min_length + self.fixed_size = fixed_size + self.variable_name = variable_name + self.optional = optional + + +@doctest_depends_on(modules=('matchpy',)) +class _WildAbstract(Wildcard, Symbol): + min_length: int # abstract field required in subclasses + fixed_size: bool # abstract field required in subclasses + + def __init__(self, variable_name=None, optional=None, **assumptions): + min_length = self.min_length + fixed_size = self.fixed_size + if optional is not None: + optional = _sympify(optional) + Wildcard.__init__(self, min_length, fixed_size, str(variable_name), optional) + + def __getstate__(self): + return { + "min_length": self.min_length, + "fixed_size": self.fixed_size, + "min_count": self.min_count, + "variable_name": self.variable_name, + "optional": self.optional, + } + + def __new__(cls, variable_name=None, optional=None, **assumptions): + cls._sanitize(assumptions, cls) + return _WildAbstract.__xnew__(cls, variable_name, optional, **assumptions) + + def __getnewargs__(self): + return self.variable_name, self.optional + + @staticmethod + def __xnew__(cls, variable_name=None, optional=None, **assumptions): + obj = Symbol.__xnew__(cls, variable_name, **assumptions) + return obj + + def _hashable_content(self): + if self.optional: + return super()._hashable_content() + (self.min_count, self.fixed_size, self.variable_name, self.optional) + else: + return super()._hashable_content() + (self.min_count, self.fixed_size, self.variable_name) + + def __copy__(self) -> '_WildAbstract': + return type(self)(variable_name=self.variable_name, optional=self.optional) + + def __repr__(self): + return str(self) + + def __str__(self): + return self.name + + +@doctest_depends_on(modules=('matchpy',)) +class WildDot(_WildAbstract): + min_length = 1 + fixed_size = True + + +@doctest_depends_on(modules=('matchpy',)) +class WildPlus(_WildAbstract): + min_length = 1 + fixed_size = False + + +@doctest_depends_on(modules=('matchpy',)) +class WildStar(_WildAbstract): + min_length = 0 + fixed_size = False + + +def _get_srepr(expr): + s = srepr(expr) + s = re.sub(r"WildDot\('(\w+)'\)", r"\1", s) + s = re.sub(r"WildPlus\('(\w+)'\)", r"*\1", s) + s = re.sub(r"WildStar\('(\w+)'\)", r"*\1", s) + return s + + +class ReplacementInfo(NamedTuple): + replacement: Any + info: Any + + +@doctest_depends_on(modules=('matchpy',)) +class Replacer: + """ + Replacer object to perform multiple pattern matching and subexpression + replacements in SymPy expressions. + + Examples + ======== + + Example to construct a simple first degree equation solver: + + >>> from sympy.utilities.matchpy_connector import WildDot, Replacer + >>> from sympy import Equality, Symbol + >>> x = Symbol("x") + >>> a_ = WildDot("a_", optional=1) + >>> b_ = WildDot("b_", optional=0) + + The lines above have defined two wildcards, ``a_`` and ``b_``, the + coefficients of the equation `a x + b = 0`. The optional values specified + indicate which expression to return in case no match is found, they are + necessary in equations like `a x = 0` and `x + b = 0`. + + Create two constraints to make sure that ``a_`` and ``b_`` will not match + any expression containing ``x``: + + >>> from matchpy import CustomConstraint + >>> free_x_a = CustomConstraint(lambda a_: not a_.has(x)) + >>> free_x_b = CustomConstraint(lambda b_: not b_.has(x)) + + Now create the rule replacer with the constraints: + + >>> replacer = Replacer(common_constraints=[free_x_a, free_x_b]) + + Add the matching rule: + + >>> replacer.add(Equality(a_*x + b_, 0), -b_/a_) + + Let's try it: + + >>> replacer.replace(Equality(3*x + 4, 0)) + -4/3 + + Notice that it will not match equations expressed with other patterns: + + >>> eq = Equality(3*x, 4) + >>> replacer.replace(eq) + Eq(3*x, 4) + + In order to extend the matching patterns, define another one (we also need + to clear the cache, because the previous result has already been memorized + and the pattern matcher will not iterate again if given the same expression) + + >>> replacer.add(Equality(a_*x, b_), b_/a_) + >>> replacer._matcher.clear() + >>> replacer.replace(eq) + 4/3 + """ + + def __init__(self, common_constraints: list = [], lambdify: bool = False, info: bool = False): + self._matcher = matchpy.ManyToOneMatcher() + self._common_constraint = common_constraints + self._lambdify = lambdify + self._info = info + self._wildcards: Dict[str, Wildcard] = {} + + def _get_lambda(self, lambda_str: str) -> Callable[..., Expr]: + exec("from sympy import *") + return eval(lambda_str, locals()) + + def _get_custom_constraint(self, constraint_expr: Expr, condition_template: str) -> Callable[..., Expr]: + wilds = [x.name for x in constraint_expr.atoms(_WildAbstract)] + lambdaargs = ', '.join(wilds) + fullexpr = _get_srepr(constraint_expr) + condition = condition_template.format(fullexpr) + return matchpy.CustomConstraint( + self._get_lambda(f"lambda {lambdaargs}: ({condition})")) + + def _get_custom_constraint_nonfalse(self, constraint_expr: Expr) -> Callable[..., Expr]: + return self._get_custom_constraint(constraint_expr, "({}) != False") + + def _get_custom_constraint_true(self, constraint_expr: Expr) -> Callable[..., Expr]: + return self._get_custom_constraint(constraint_expr, "({}) == True") + + def add(self, expr: Expr, replacement, conditions_true: List[Expr] = [], + conditions_nonfalse: List[Expr] = [], info: Any = None) -> None: + expr = _sympify(expr) + replacement = _sympify(replacement) + constraints = self._common_constraint[:] + constraint_conditions_true = [ + self._get_custom_constraint_true(cond) for cond in conditions_true] + constraint_conditions_nonfalse = [ + self._get_custom_constraint_nonfalse(cond) for cond in conditions_nonfalse] + constraints.extend(constraint_conditions_true) + constraints.extend(constraint_conditions_nonfalse) + pattern = matchpy.Pattern(expr, *constraints) + if self._lambdify: + lambda_str = f"lambda {', '.join((x.name for x in expr.atoms(_WildAbstract)))}: {_get_srepr(replacement)}" + lambda_expr = self._get_lambda(lambda_str) + replacement = lambda_expr + else: + self._wildcards.update({str(i): i for i in expr.atoms(Wildcard)}) + if self._info: + replacement = ReplacementInfo(replacement, info) + self._matcher.add(pattern, replacement) + + def replace(self, expression, max_count: int = -1): + # This method partly rewrites the .replace method of ManyToOneReplacer + # in MatchPy. + # License: https://github.com/HPAC/matchpy/blob/master/LICENSE + infos = [] + replaced = True + replace_count = 0 + while replaced and (max_count < 0 or replace_count < max_count): + replaced = False + for subexpr, pos in matchpy.preorder_iter_with_position(expression): + try: + replacement_data, subst = next(iter(self._matcher.match(subexpr))) + if self._info: + replacement = replacement_data.replacement + infos.append(replacement_data.info) + else: + replacement = replacement_data + + if self._lambdify: + result = replacement(**subst) + else: + result = replacement.xreplace({self._wildcards[k]: v for k, v in subst.items()}) + + expression = matchpy.functions.replace(expression, pos, result) + replaced = True + break + except StopIteration: + pass + replace_count += 1 + if self._info: + return expression, infos + else: + return expression diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/memoization.py b/.venv/lib/python3.13/site-packages/sympy/utilities/memoization.py new file mode 100644 index 0000000000000000000000000000000000000000..b638dfe244628096108ea689b664782f6538b7b8 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/memoization.py @@ -0,0 +1,76 @@ +from functools import wraps + + +def recurrence_memo(initial): + """ + Memo decorator for sequences defined by recurrence + + Examples + ======== + + >>> from sympy.utilities.memoization import recurrence_memo + >>> @recurrence_memo([1]) # 0! = 1 + ... def factorial(n, prev): + ... return n * prev[-1] + >>> factorial(4) + 24 + >>> factorial(3) # use cache values + 6 + >>> factorial.cache_length() # cache length can be obtained + 5 + >>> factorial.fetch_item(slice(2, 4)) + [2, 6] + + """ + cache = initial + + def decorator(f): + @wraps(f) + def g(n): + L = len(cache) + if n < L: + return cache[n] + for i in range(L, n + 1): + cache.append(f(i, cache)) + return cache[-1] + g.cache_length = lambda: len(cache) + g.fetch_item = lambda x: cache[x] + return g + return decorator + + +def assoc_recurrence_memo(base_seq): + """ + Memo decorator for associated sequences defined by recurrence starting from base + + base_seq(n) -- callable to get base sequence elements + + XXX works only for Pn0 = base_seq(0) cases + XXX works only for m <= n cases + """ + + cache = [] + + def decorator(f): + @wraps(f) + def g(n, m): + L = len(cache) + if n < L: + return cache[n][m] + + for i in range(L, n + 1): + # get base sequence + F_i0 = base_seq(i) + F_i_cache = [F_i0] + cache.append(F_i_cache) + + # XXX only works for m <= n cases + # generate assoc sequence + for j in range(1, i + 1): + F_ij = f(i, j, cache) + F_i_cache.append(F_ij) + + return cache[n][m] + + return g + return decorator diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/misc.py b/.venv/lib/python3.13/site-packages/sympy/utilities/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..741b983f03272890b22f2545f67b141767507634 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/misc.py @@ -0,0 +1,564 @@ +"""Miscellaneous stuff that does not really fit anywhere else.""" + +from __future__ import annotations + +import operator +import sys +import os +import re as _re +import struct +from textwrap import fill, dedent + + +class Undecidable(ValueError): + # an error to be raised when a decision cannot be made definitively + # where a definitive answer is needed + pass + + +def filldedent(s, w=70, **kwargs): + """ + Strips leading and trailing empty lines from a copy of ``s``, then dedents, + fills and returns it. + + Empty line stripping serves to deal with docstrings like this one that + start with a newline after the initial triple quote, inserting an empty + line at the beginning of the string. + + Additional keyword arguments will be passed to ``textwrap.fill()``. + + See Also + ======== + strlines, rawlines + + """ + return '\n' + fill(dedent(str(s)).strip('\n'), width=w, **kwargs) + + +def strlines(s, c=64, short=False): + """Return a cut-and-pastable string that, when printed, is + equivalent to the input. The lines will be surrounded by + parentheses and no line will be longer than c (default 64) + characters. If the line contains newlines characters, the + `rawlines` result will be returned. If ``short`` is True + (default is False) then if there is one line it will be + returned without bounding parentheses. + + Examples + ======== + + >>> from sympy.utilities.misc import strlines + >>> q = 'this is a long string that should be broken into shorter lines' + >>> print(strlines(q, 40)) + ( + 'this is a long string that should be b' + 'roken into shorter lines' + ) + >>> q == ( + ... 'this is a long string that should be b' + ... 'roken into shorter lines' + ... ) + True + + See Also + ======== + filldedent, rawlines + """ + if not isinstance(s, str): + raise ValueError('expecting string input') + if '\n' in s: + return rawlines(s) + q = '"' if repr(s).startswith('"') else "'" + q = (q,)*2 + if '\\' in s: # use r-string + m = '(\nr%s%%s%s\n)' % q + j = '%s\nr%s' % q + c -= 3 + else: + m = '(\n%s%%s%s\n)' % q + j = '%s\n%s' % q + c -= 2 + out = [] + while s: + out.append(s[:c]) + s=s[c:] + if short and len(out) == 1: + return (m % out[0]).splitlines()[1] # strip bounding (\n...\n) + return m % j.join(out) + + +def rawlines(s): + """Return a cut-and-pastable string that, when printed, is equivalent + to the input. Use this when there is more than one line in the + string. The string returned is formatted so it can be indented + nicely within tests; in some cases it is wrapped in the dedent + function which has to be imported from textwrap. + + Examples + ======== + + Note: because there are characters in the examples below that need + to be escaped because they are themselves within a triple quoted + docstring, expressions below look more complicated than they would + be if they were printed in an interpreter window. + + >>> from sympy.utilities.misc import rawlines + >>> from sympy import TableForm + >>> s = str(TableForm([[1, 10]], headings=(None, ['a', 'bee']))) + >>> print(rawlines(s)) + ( + 'a bee\\n' + '-----\\n' + '1 10 ' + ) + >>> print(rawlines('''this + ... that''')) + dedent('''\\ + this + that''') + + >>> print(rawlines('''this + ... that + ... ''')) + dedent('''\\ + this + that + ''') + + >>> s = \"\"\"this + ... is a triple ''' + ... \"\"\" + >>> print(rawlines(s)) + dedent(\"\"\"\\ + this + is a triple ''' + \"\"\") + + >>> print(rawlines('''this + ... that + ... ''')) + ( + 'this\\n' + 'that\\n' + ' ' + ) + + See Also + ======== + filldedent, strlines + """ + lines = s.split('\n') + if len(lines) == 1: + return repr(lines[0]) + triple = ["'''" in s, '"""' in s] + if any(li.endswith(' ') for li in lines) or '\\' in s or all(triple): + rv = [] + # add on the newlines + trailing = s.endswith('\n') + last = len(lines) - 1 + for i, li in enumerate(lines): + if i != last or trailing: + rv.append(repr(li + '\n')) + else: + rv.append(repr(li)) + return '(\n %s\n)' % '\n '.join(rv) + else: + rv = '\n '.join(lines) + if triple[0]: + return 'dedent("""\\\n %s""")' % rv + else: + return "dedent('''\\\n %s''')" % rv + +ARCH = str(struct.calcsize('P') * 8) + "-bit" + + +# XXX: PyPy does not support hash randomization +HASH_RANDOMIZATION = getattr(sys.flags, 'hash_randomization', False) + +_debug_tmp: list[str] = [] +_debug_iter = 0 + +def debug_decorator(func): + """If SYMPY_DEBUG is True, it will print a nice execution tree with + arguments and results of all decorated functions, else do nothing. + """ + from sympy import SYMPY_DEBUG + + if not SYMPY_DEBUG: + return func + + def maketree(f, *args, **kw): + global _debug_tmp, _debug_iter + oldtmp = _debug_tmp + _debug_tmp = [] + _debug_iter += 1 + + def tree(subtrees): + def indent(s, variant=1): + x = s.split("\n") + r = "+-%s\n" % x[0] + for a in x[1:]: + if a == "": + continue + if variant == 1: + r += "| %s\n" % a + else: + r += " %s\n" % a + return r + if len(subtrees) == 0: + return "" + f = [] + for a in subtrees[:-1]: + f.append(indent(a)) + f.append(indent(subtrees[-1], 2)) + return ''.join(f) + + # If there is a bug and the algorithm enters an infinite loop, enable the + # following lines. It will print the names and parameters of all major functions + # that are called, *before* they are called + #from functools import reduce + #print("%s%s %s%s" % (_debug_iter, reduce(lambda x, y: x + y, \ + # map(lambda x: '-', range(1, 2 + _debug_iter))), f.__name__, args)) + + r = f(*args, **kw) + + _debug_iter -= 1 + s = "%s%s = %s\n" % (f.__name__, args, r) + if _debug_tmp != []: + s += tree(_debug_tmp) + _debug_tmp = oldtmp + _debug_tmp.append(s) + if _debug_iter == 0: + print(_debug_tmp[0]) + _debug_tmp = [] + return r + + def decorated(*args, **kwargs): + return maketree(func, *args, **kwargs) + + return decorated + + +def debug(*args): + """ + Print ``*args`` if SYMPY_DEBUG is True, else do nothing. + """ + from sympy import SYMPY_DEBUG + if SYMPY_DEBUG: + print(*args, file=sys.stderr) + + +def debugf(string, args): + """ + Print ``string%args`` if SYMPY_DEBUG is True, else do nothing. This is + intended for debug messages using formatted strings. + """ + from sympy import SYMPY_DEBUG + if SYMPY_DEBUG: + print(string%args, file=sys.stderr) + + +def find_executable(executable, path=None): + """Try to find 'executable' in the directories listed in 'path' (a + string listing directories separated by 'os.pathsep'; defaults to + os.environ['PATH']). Returns the complete filename or None if not + found + """ + from .exceptions import sympy_deprecation_warning + sympy_deprecation_warning( + """ + sympy.utilities.misc.find_executable() is deprecated. Use the standard + library shutil.which() function instead. + """, + deprecated_since_version="1.7", + active_deprecations_target="deprecated-find-executable", + ) + if path is None: + path = os.environ['PATH'] + paths = path.split(os.pathsep) + extlist = [''] + if os.name == 'os2': + (base, ext) = os.path.splitext(executable) + # executable files on OS/2 can have an arbitrary extension, but + # .exe is automatically appended if no dot is present in the name + if not ext: + executable = executable + ".exe" + elif sys.platform == 'win32': + pathext = os.environ['PATHEXT'].lower().split(os.pathsep) + (base, ext) = os.path.splitext(executable) + if ext.lower() not in pathext: + extlist = pathext + for ext in extlist: + execname = executable + ext + if os.path.isfile(execname): + return execname + else: + for p in paths: + f = os.path.join(p, execname) + if os.path.isfile(f): + return f + + return None + + +def func_name(x, short=False): + """Return function name of `x` (if defined) else the `type(x)`. + If short is True and there is a shorter alias for the result, + return the alias. + + Examples + ======== + + >>> from sympy.utilities.misc import func_name + >>> from sympy import Matrix + >>> from sympy.abc import x + >>> func_name(Matrix.eye(3)) + 'MutableDenseMatrix' + >>> func_name(x < 1) + 'StrictLessThan' + >>> func_name(x < 1, short=True) + 'Lt' + """ + alias = { + 'GreaterThan': 'Ge', + 'StrictGreaterThan': 'Gt', + 'LessThan': 'Le', + 'StrictLessThan': 'Lt', + 'Equality': 'Eq', + 'Unequality': 'Ne', + } + typ = type(x) + if str(typ).startswith(">> from sympy.utilities.misc import _replace + >>> f = _replace(dict(foo='bar', d='t')) + >>> f('food') + 'bart' + >>> f = _replace({}) + >>> f('food') + 'food' + """ + if not reps: + return lambda x: x + D = lambda match: reps[match.group(0)] + pattern = _re.compile("|".join( + [_re.escape(k) for k, v in reps.items()]), _re.MULTILINE) + return lambda string: pattern.sub(D, string) + + +def replace(string, *reps): + """Return ``string`` with all keys in ``reps`` replaced with + their corresponding values, longer strings first, irrespective + of the order they are given. ``reps`` may be passed as tuples + or a single mapping. + + Examples + ======== + + >>> from sympy.utilities.misc import replace + >>> replace('foo', {'oo': 'ar', 'f': 'b'}) + 'bar' + >>> replace("spamham sha", ("spam", "eggs"), ("sha","md5")) + 'eggsham md5' + + There is no guarantee that a unique answer will be + obtained if keys in a mapping overlap (i.e. are the same + length and have some identical sequence at the + beginning/end): + + >>> reps = [ + ... ('ab', 'x'), + ... ('bc', 'y')] + >>> replace('abc', *reps) in ('xc', 'ay') + True + + References + ========== + + .. [1] https://stackoverflow.com/questions/6116978/how-to-replace-multiple-substrings-of-a-string + """ + if len(reps) == 1: + kv = reps[0] + if isinstance(kv, dict): + reps = kv + else: + return string.replace(*kv) + else: + reps = dict(reps) + return _replace(reps)(string) + + +def translate(s, a, b=None, c=None): + """Return ``s`` where characters have been replaced or deleted. + + SYNTAX + ====== + + translate(s, None, deletechars): + all characters in ``deletechars`` are deleted + translate(s, map [,deletechars]): + all characters in ``deletechars`` (if provided) are deleted + then the replacements defined by map are made; if the keys + of map are strings then the longer ones are handled first. + Multicharacter deletions should have a value of ''. + translate(s, oldchars, newchars, deletechars) + all characters in ``deletechars`` are deleted + then each character in ``oldchars`` is replaced with the + corresponding character in ``newchars`` + + Examples + ======== + + >>> from sympy.utilities.misc import translate + >>> abc = 'abc' + >>> translate(abc, None, 'a') + 'bc' + >>> translate(abc, {'a': 'x'}, 'c') + 'xb' + >>> translate(abc, {'abc': 'x', 'a': 'y'}) + 'x' + + >>> translate('abcd', 'ac', 'AC', 'd') + 'AbC' + + There is no guarantee that a unique answer will be + obtained if keys in a mapping overlap are the same + length and have some identical sequences at the + beginning/end: + + >>> translate(abc, {'ab': 'x', 'bc': 'y'}) in ('xc', 'ay') + True + """ + + mr = {} + if a is None: + if c is not None: + raise ValueError('c should be None when a=None is passed, instead got %s' % c) + if b is None: + return s + c = b + a = b = '' + else: + if isinstance(a, dict): + short = {} + for k in list(a.keys()): + if len(k) == 1 and len(a[k]) == 1: + short[k] = a.pop(k) + mr = a + c = b + if short: + a, b = [''.join(i) for i in list(zip(*short.items()))] + else: + a = b = '' + elif len(a) != len(b): + raise ValueError('oldchars and newchars have different lengths') + + if c: + val = str.maketrans('', '', c) + s = s.translate(val) + s = replace(s, mr) + n = str.maketrans(a, b) + return s.translate(n) + + +def ordinal(num): + """Return ordinal number string of num, e.g. 1 becomes 1st. + """ + # modified from https://codereview.stackexchange.com/questions/41298/producing-ordinal-numbers + n = as_int(num) + k = abs(n) % 100 + if 11 <= k <= 13: + suffix = 'th' + elif k % 10 == 1: + suffix = 'st' + elif k % 10 == 2: + suffix = 'nd' + elif k % 10 == 3: + suffix = 'rd' + else: + suffix = 'th' + return str(n) + suffix + + +def as_int(n, strict=True): + """ + Convert the argument to a builtin integer. + + The return value is guaranteed to be equal to the input. ValueError is + raised if the input has a non-integral value. When ``strict`` is True, this + uses `__index__ `_ + and when it is False it uses ``int``. + + + Examples + ======== + + >>> from sympy.utilities.misc import as_int + >>> from sympy import sqrt, S + + The function is primarily concerned with sanitizing input for + functions that need to work with builtin integers, so anything that + is unambiguously an integer should be returned as an int: + + >>> as_int(S(3)) + 3 + + Floats, being of limited precision, are not assumed to be exact and + will raise an error unless the ``strict`` flag is False. This + precision issue becomes apparent for large floating point numbers: + + >>> big = 1e23 + >>> type(big) is float + True + >>> big == int(big) + True + >>> as_int(big) + Traceback (most recent call last): + ... + ValueError: ... is not an integer + >>> as_int(big, strict=False) + 99999999999999991611392 + + Input that might be a complex representation of an integer value is + also rejected by default: + + >>> one = sqrt(3 + 2*sqrt(2)) - sqrt(2) + >>> int(one) == 1 + True + >>> as_int(one) + Traceback (most recent call last): + ... + ValueError: ... is not an integer + """ + if strict: + try: + if isinstance(n, bool): + raise TypeError + return operator.index(n) + except TypeError: + raise ValueError('%s is not an integer' % (n,)) + else: + try: + result = int(n) + except TypeError: + raise ValueError('%s is not an integer' % (n,)) + if n - result: + raise ValueError('%s is not an integer' % (n,)) + return result diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/pkgdata.py b/.venv/lib/python3.13/site-packages/sympy/utilities/pkgdata.py new file mode 100644 index 0000000000000000000000000000000000000000..8bf2065759362ee09a252d4736bd612b8d271e72 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/pkgdata.py @@ -0,0 +1,33 @@ +# This module is deprecated and will be removed. + +import sys +import os +from io import StringIO + +from sympy.utilities.decorator import deprecated + + +@deprecated( + """ + The sympy.utilities.pkgdata module and its get_resource function are + deprecated. Use the stdlib importlib.resources module instead. + """, + deprecated_since_version="1.12", + active_deprecations_target="pkgdata", +) +def get_resource(identifier, pkgname=__name__): + + mod = sys.modules[pkgname] + fn = getattr(mod, '__file__', None) + if fn is None: + raise OSError("%r has no __file__!") + path = os.path.join(os.path.dirname(fn), identifier) + loader = getattr(mod, '__loader__', None) + if loader is not None: + try: + data = loader.get_data(path) + except (OSError, AttributeError): + pass + else: + return StringIO(data.decode('utf-8')) + return open(os.path.normpath(path), 'rb') diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/pytest.py b/.venv/lib/python3.13/site-packages/sympy/utilities/pytest.py new file mode 100644 index 0000000000000000000000000000000000000000..75494c8e987c7b89bddf18febe09fdc3df2b194e --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/pytest.py @@ -0,0 +1,12 @@ +""" +.. deprecated:: 1.6 + + sympy.utilities.pytest has been renamed to sympy.testing.pytest. +""" +from sympy.utilities.exceptions import sympy_deprecation_warning + +sympy_deprecation_warning("The sympy.utilities.pytest submodule is deprecated. Use sympy.testing.pytest instead.", + deprecated_since_version="1.6", + active_deprecations_target="deprecated-sympy-utilities-submodules") + +from sympy.testing.pytest import * # noqa:F401,F403 diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/randtest.py b/.venv/lib/python3.13/site-packages/sympy/utilities/randtest.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd3472ed8a89198d9e78e125f16dae1a56f6bad --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/randtest.py @@ -0,0 +1,12 @@ +""" +.. deprecated:: 1.6 + + sympy.utilities.randtest has been renamed to sympy.core.random. +""" +from sympy.utilities.exceptions import sympy_deprecation_warning + +sympy_deprecation_warning("The sympy.utilities.randtest submodule is deprecated. Use sympy.core.random instead.", + deprecated_since_version="1.6", + active_deprecations_target="deprecated-sympy-utilities-submodules") + +from sympy.core.random import * # noqa:F401,F403 diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/runtests.py b/.venv/lib/python3.13/site-packages/sympy/utilities/runtests.py new file mode 100644 index 0000000000000000000000000000000000000000..bb9f760f070be528e8cd51ab38e00bc944dda9ac --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/runtests.py @@ -0,0 +1,13 @@ +""" +.. deprecated:: 1.6 + + sympy.utilities.runtests has been renamed to sympy.testing.runtests. +""" + +from sympy.utilities.exceptions import sympy_deprecation_warning + +sympy_deprecation_warning("The sympy.utilities.runtests submodule is deprecated. Use sympy.testing.runtests instead.", + deprecated_since_version="1.6", + active_deprecations_target="deprecated-sympy-utilities-submodules") + +from sympy.testing.runtests import * # noqa: F401,F403 diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/source.py b/.venv/lib/python3.13/site-packages/sympy/utilities/source.py new file mode 100644 index 0000000000000000000000000000000000000000..71692b4aaad07df70b63d7e1eaa86c402ad03c4f --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/source.py @@ -0,0 +1,40 @@ +""" +This module adds several functions for interactive source code inspection. +""" + + +def get_class(lookup_view): + """ + Convert a string version of a class name to the object. + + For example, get_class('sympy.core.Basic') will return + class Basic located in module sympy.core + """ + if isinstance(lookup_view, str): + mod_name, func_name = get_mod_func(lookup_view) + if func_name != '': + lookup_view = getattr( + __import__(mod_name, {}, {}, ['*']), func_name) + if not callable(lookup_view): + raise AttributeError( + "'%s.%s' is not a callable." % (mod_name, func_name)) + return lookup_view + + +def get_mod_func(callback): + """ + splits the string path to a class into a string path to the module + and the name of the class. + + Examples + ======== + + >>> from sympy.utilities.source import get_mod_func + >>> get_mod_func('sympy.core.basic.Basic') + ('sympy.core.basic', 'Basic') + + """ + dot = callback.rfind('.') + if dot == -1: + return callback, '' + return callback[:dot], callback[dot + 1:] diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/timeutils.py b/.venv/lib/python3.13/site-packages/sympy/utilities/timeutils.py new file mode 100644 index 0000000000000000000000000000000000000000..1caae825103335f3e8afedeaf741cff2b0414fb2 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/timeutils.py @@ -0,0 +1,75 @@ +"""Simple tools for timing functions' execution, when IPython is not available. """ + + +import timeit +import math + + +_scales = [1e0, 1e3, 1e6, 1e9] +_units = ['s', 'ms', '\N{GREEK SMALL LETTER MU}s', 'ns'] + + +def timed(func, setup="pass", limit=None): + """Adaptively measure execution time of a function. """ + timer = timeit.Timer(func, setup=setup) + repeat, number = 3, 1 + + for i in range(1, 10): + if timer.timeit(number) >= 0.2: + break + elif limit is not None and number >= limit: + break + else: + number *= 10 + + time = min(timer.repeat(repeat, number)) / number + + if time > 0.0: + order = min(-int(math.floor(math.log10(time)) // 3), 3) + else: + order = 3 + + return (number, time, time*_scales[order], _units[order]) + + +# Code for doing inline timings of recursive algorithms. + +def __do_timings(): + import os + res = os.getenv('SYMPY_TIMINGS', '') + res = [x.strip() for x in res.split(',')] + return set(res) + +_do_timings = __do_timings() +_timestack = None + + +def _print_timestack(stack, level=1): + print('-'*level, '%.2f %s%s' % (stack[2], stack[0], stack[3])) + for s in stack[1]: + _print_timestack(s, level + 1) + + +def timethis(name): + def decorator(func): + if name not in _do_timings: + return func + + def wrapper(*args, **kwargs): + from time import time + global _timestack + oldtimestack = _timestack + _timestack = [func.func_name, [], 0, args] + t1 = time() + r = func(*args, **kwargs) + t2 = time() + _timestack[2] = t2 - t1 + if oldtimestack is not None: + oldtimestack[1].append(_timestack) + _timestack = oldtimestack + else: + _print_timestack(_timestack) + _timestack = None + return r + return wrapper + return decorator diff --git a/.venv/lib/python3.13/site-packages/sympy/utilities/tmpfiles.py b/.venv/lib/python3.13/site-packages/sympy/utilities/tmpfiles.py new file mode 100644 index 0000000000000000000000000000000000000000..81c97efbf9d67cc20f25ff251abc94c2f5bafe06 --- /dev/null +++ b/.venv/lib/python3.13/site-packages/sympy/utilities/tmpfiles.py @@ -0,0 +1,12 @@ +""" +.. deprecated:: 1.6 + + sympy.utilities.tmpfiles has been renamed to sympy.testing.tmpfiles. +""" +from sympy.utilities.exceptions import sympy_deprecation_warning + +sympy_deprecation_warning("The sympy.utilities.tmpfiles submodule is deprecated. Use sympy.testing.tmpfiles instead.", + deprecated_since_version="1.6", + active_deprecations_target="deprecated-sympy-utilities-submodules") + +from sympy.testing.tmpfiles import * # noqa:F401,F403 diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000097.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000097.json new file mode 100644 index 0000000000000000000000000000000000000000..e4a5d45d89747fcaa5ea52173ad60597e09405d0 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000097.json @@ -0,0 +1,30 @@ +{ + "frame_idx": 97, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [ + { + "frame_idx": 97, + "class": "qr code", + "bbox": [ + 300.0, + 388.0, + 401.0, + 462.0 + ], + "score": 0.9517342448234558, + "track_id": "auto-0" + }, + { + "frame_idx": 97, + "class": "box", + "bbox": [ + 27.0, + 86.0, + 1094.0, + 582.0 + ], + "score": 0.9266112446784973, + "track_id": "auto-1" + } + ] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000214.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000214.json new file mode 100644 index 0000000000000000000000000000000000000000..161aec9afb7d41c39e074e23de4f038e37a16482 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000214.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 214, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000351.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000351.json new file mode 100644 index 0000000000000000000000000000000000000000..89d3b6c6ba859de6eca7e4e17a23c1077a8ae30a --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000351.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 351, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000582.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000582.json new file mode 100644 index 0000000000000000000000000000000000000000..433593952b21ef24e383bdf114ab4198007bd3a7 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000582.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 582, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000644.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000644.json new file mode 100644 index 0000000000000000000000000000000000000000..14f1163c6246fa3c57ab01e366ed483be85f2e6d --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000644.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 644, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000701.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000701.json new file mode 100644 index 0000000000000000000000000000000000000000..a78f2710c8ff410da15e5225e424266f79b2b6f7 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000701.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 701, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000994.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000994.json new file mode 100644 index 0000000000000000000000000000000000000000..52471b4c79a19283e52185200fbb822c60f5b7e4 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/000994.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 994, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001055.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001055.json new file mode 100644 index 0000000000000000000000000000000000000000..d4aa193922443f9d9d1fb29f13c8e013cab6f28e --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001055.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 1055, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001110.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001110.json new file mode 100644 index 0000000000000000000000000000000000000000..c5245261c421882d15cbf9c6f1f61940a5a71f93 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001110.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 1110, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001405.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001405.json new file mode 100644 index 0000000000000000000000000000000000000000..f810c2fac158c9b9b88b6da421f0615bbdc94944 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001405.json @@ -0,0 +1,18 @@ +{ + "frame_idx": 1405, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [ + { + "frame_idx": 1405, + "class": "box", + "bbox": [ + 2.0, + 248.0, + 160.0, + 465.0 + ], + "score": 0.7260979413986206, + "track_id": "auto-0" + } + ] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001540.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001540.json new file mode 100644 index 0000000000000000000000000000000000000000..5cc047c9aa5d0f7ff0af30ee93afee5d57840131 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001540.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 1540, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001686.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001686.json new file mode 100644 index 0000000000000000000000000000000000000000..a1bae9fbf3f7d764412e3e96fca9fceedd164ace --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001686.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 1686, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001813.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001813.json new file mode 100644 index 0000000000000000000000000000000000000000..7eecae032669c64a312c3530e40da349345db435 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001813.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 1813, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001956.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001956.json new file mode 100644 index 0000000000000000000000000000000000000000..a8a27da632953aea07cf2f007732a0e934c2ff12 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/001956.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 1956, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002045.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002045.json new file mode 100644 index 0000000000000000000000000000000000000000..b4c48b2be033593d2498038c1e650895035589f6 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002045.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 2045, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002100.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002100.json new file mode 100644 index 0000000000000000000000000000000000000000..e7bdb4b86f363414ce7ea38b47cf55c14a6170a3 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002100.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 2100, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002229.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002229.json new file mode 100644 index 0000000000000000000000000000000000000000..446a45770bdf21c30bdad28e3ce45dc9ec343e5c --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002229.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 2229, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002383.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002383.json new file mode 100644 index 0000000000000000000000000000000000000000..f058fe276b9d22bb19d811cab1bb0b383436258b --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002383.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 2383, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002550.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002550.json new file mode 100644 index 0000000000000000000000000000000000000000..c2976d6709e887eaa4238a9e4e97a7d0d843fcfa --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002550.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 2550, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002679.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002679.json new file mode 100644 index 0000000000000000000000000000000000000000..0d1a81151f1bb68748d8d6a4a6cb6642b256a773 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002679.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 2679, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002696.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002696.json new file mode 100644 index 0000000000000000000000000000000000000000..5ee579676fa6e99e43c5dbcf14c5f7505d2e7ac4 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002696.json @@ -0,0 +1,30 @@ +{ + "frame_idx": 2696, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [ + { + "frame_idx": 2696, + "class": "qr code", + "bbox": [ + 145.0, + 391.0, + 257.0, + 468.0 + ], + "score": 0.9691669344902039, + "track_id": "auto-0" + }, + { + "frame_idx": 2696, + "class": "box", + "bbox": [ + 56.0, + 94.0, + 961.0, + 576.0 + ], + "score": 0.899407684803009, + "track_id": "auto-1" + } + ] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002803.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002803.json new file mode 100644 index 0000000000000000000000000000000000000000..3fc5238a619d83cc7a302b0837f50fc859843a29 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/002803.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 2803, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003068.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003068.json new file mode 100644 index 0000000000000000000000000000000000000000..31cce2b30ce63554f3e40174bf85a4186257863e --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003068.json @@ -0,0 +1,18 @@ +{ + "frame_idx": 3068, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [ + { + "frame_idx": 3068, + "class": "box", + "bbox": [ + 828.0, + 0.0, + 1278.0, + 456.0 + ], + "score": 0.9242427349090576, + "track_id": "auto-0" + } + ] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003087.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003087.json new file mode 100644 index 0000000000000000000000000000000000000000..63be3ff97ebed89f3c02dac898c0d24231345ee9 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003087.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 3087, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003341.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003341.json new file mode 100644 index 0000000000000000000000000000000000000000..9ab796ad6dd866ffa53e3089efb895a4e44cec5d --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003341.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 3341, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003438.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003438.json new file mode 100644 index 0000000000000000000000000000000000000000..561f37e9c2c4afdfa22d0b5186c4260ae999936a --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003438.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 3438, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003592.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003592.json new file mode 100644 index 0000000000000000000000000000000000000000..6cdf6e6000c8b2fbaaf3de8f784833b44d358775 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003592.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 3592, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003711.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003711.json new file mode 100644 index 0000000000000000000000000000000000000000..a9e2b2adf671112b09b277054f49dd937415489d --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003711.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 3711, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003984.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003984.json new file mode 100644 index 0000000000000000000000000000000000000000..6a2b817127d994d1d75175ed581d1e1cb98ab5b1 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/003984.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 3984, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/004257.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/004257.json new file mode 100644 index 0000000000000000000000000000000000000000..486730a78922a589d2ece8c7679508120e4623c0 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/004257.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 4257, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/004312.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/004312.json new file mode 100644 index 0000000000000000000000000000000000000000..88c930e8c6c565065b451edf99c92da1b82a5121 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/004312.json @@ -0,0 +1,18 @@ +{ + "frame_idx": 4312, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [ + { + "frame_idx": 4312, + "class": "box", + "bbox": [ + 412.0, + 155.0, + 1280.0, + 593.0 + ], + "score": 0.8329508304595947, + "track_id": "auto-0" + } + ] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/004484.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/004484.json new file mode 100644 index 0000000000000000000000000000000000000000..75351ab271b79e38fdfb7d3c0584a1775f821829 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/004484.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 4484, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/004607.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/004607.json new file mode 100644 index 0000000000000000000000000000000000000000..350316977215c2638e858b35414b673ae8f1b18f --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/004607.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 4607, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/004742.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/004742.json new file mode 100644 index 0000000000000000000000000000000000000000..d1ff3a09bebf6d8ddd5dbb587242c1702a8b37da --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/004742.json @@ -0,0 +1,30 @@ +{ + "frame_idx": 4742, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [ + { + "frame_idx": 4742, + "class": "box", + "bbox": [ + 812.0, + 182.0, + 1280.0, + 507.0 + ], + "score": 0.8834483027458191, + "track_id": "auto-0" + }, + { + "frame_idx": 4742, + "class": "box", + "bbox": [ + 406.0, + 424.0, + 831.0, + 621.0 + ], + "score": 0.5904803276062012, + "track_id": "auto-1" + } + ] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005016.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005016.json new file mode 100644 index 0000000000000000000000000000000000000000..67d8c657344c12f48ff2ce41cecbd6fba6f849fc --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005016.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 5016, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005153.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005153.json new file mode 100644 index 0000000000000000000000000000000000000000..81b7ea5c9d07fa4f5bb22136a11d78e5ed0bff73 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005153.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 5153, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005295.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005295.json new file mode 100644 index 0000000000000000000000000000000000000000..c01b50847568c9c27e4b02328bbd540723a9e7a1 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005295.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 5295, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005446.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005446.json new file mode 100644 index 0000000000000000000000000000000000000000..ac24836c35a4f9c9c09d02b737ecdf02a6fd5e69 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005446.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 5446, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005850.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005850.json new file mode 100644 index 0000000000000000000000000000000000000000..03eceb99cd345e087b7fd19fe360921c07a819a6 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005850.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 5850, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005915.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005915.json new file mode 100644 index 0000000000000000000000000000000000000000..2beea3ccffed9667321b806ccd45ce08676e143b --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/005915.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 5915, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/006143.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/006143.json new file mode 100644 index 0000000000000000000000000000000000000000..e3d24a6dcf6abbc7d45b572aa211f6cee9dbeeee --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/006143.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 6143, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/006285.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/006285.json new file mode 100644 index 0000000000000000000000000000000000000000..ed987155b4106bcdb8cb68c9be95fadaf0f3325b --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/006285.json @@ -0,0 +1,30 @@ +{ + "frame_idx": 6285, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [ + { + "frame_idx": 6285, + "class": "qr code", + "bbox": [ + 367.0, + 435.0, + 472.0, + 508.0 + ], + "score": 0.9428966045379639, + "track_id": "auto-0" + }, + { + "frame_idx": 6285, + "class": "box", + "bbox": [ + 1.0, + 189.0, + 784.0, + 647.0 + ], + "score": 0.9114072918891907, + "track_id": "auto-1" + } + ] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/006513.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/006513.json new file mode 100644 index 0000000000000000000000000000000000000000..1717127dc58cfcbb38778ae40606ea3756dbb3e8 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/006513.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 6513, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/006790.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/006790.json new file mode 100644 index 0000000000000000000000000000000000000000..3c2f94ffba4b83eb3c4e885c1036519ab9dba4da --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/006790.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 6790, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/006840.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/006840.json new file mode 100644 index 0000000000000000000000000000000000000000..dcb1276e96893be3ce8f241018aa3133a5412a24 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/006840.json @@ -0,0 +1,42 @@ +{ + "frame_idx": 6840, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [ + { + "frame_idx": 6840, + "class": "box", + "bbox": [ + 836.0, + 70.0, + 1280.0, + 507.0 + ], + "score": 0.9347314238548279, + "track_id": "auto-0" + }, + { + "frame_idx": 6840, + "class": "box", + "bbox": [ + 0.0, + 221.0, + 522.0, + 651.0 + ], + "score": 0.9311841726303101, + "track_id": "auto-1" + }, + { + "frame_idx": 6840, + "class": "box", + "bbox": [ + 526.0, + 417.0, + 883.0, + 611.0 + ], + "score": 0.5787769556045532, + "track_id": "auto-2" + } + ] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007181.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007181.json new file mode 100644 index 0000000000000000000000000000000000000000..a57c0da8323567a6b6c26b4651e06ca3bb14a8d3 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007181.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 7181, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007247.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007247.json new file mode 100644 index 0000000000000000000000000000000000000000..2125ddd183fc29144358b94a808e15f01a39f25f --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007247.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 7247, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007302.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007302.json new file mode 100644 index 0000000000000000000000000000000000000000..297c3f415768a95862c101d548937f9b87439676 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007302.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 7302, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007494.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007494.json new file mode 100644 index 0000000000000000000000000000000000000000..63f73fd3afc4b544e654801f700f5a6ad2a1a7b4 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007494.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 7494, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007752.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007752.json new file mode 100644 index 0000000000000000000000000000000000000000..3fb142e70250a2766bb50968b973b150b0bceabd --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007752.json @@ -0,0 +1,18 @@ +{ + "frame_idx": 7752, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [ + { + "frame_idx": 7752, + "class": "box", + "bbox": [ + 2.0, + 158.0, + 632.0, + 642.0 + ], + "score": 0.9109220504760742, + "track_id": "auto-0" + } + ] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007882.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007882.json new file mode 100644 index 0000000000000000000000000000000000000000..a2f364a7d9898516572a8be8a5ccaa5ef26ba778 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007882.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 7882, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007928.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007928.json new file mode 100644 index 0000000000000000000000000000000000000000..e85252fb7926f4d609926205f74c46a09791a585 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/007928.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 7928, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/008274.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/008274.json new file mode 100644 index 0000000000000000000000000000000000000000..8492c7144859c3f7f2896ec7afc20c72e739a3b3 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/008274.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 8274, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/008331.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/008331.json new file mode 100644 index 0000000000000000000000000000000000000000..2fada932158ed44ba4b033a0abd07c29416ecb42 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/008331.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 8331, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/008624.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/008624.json new file mode 100644 index 0000000000000000000000000000000000000000..532774b7a1d7933a6fedd9a7ad9c0f679873bd3c --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/008624.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 8624, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/008761.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/008761.json new file mode 100644 index 0000000000000000000000000000000000000000..9f87d498f0816469808d11536766122683b13fc1 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/008761.json @@ -0,0 +1,18 @@ +{ + "frame_idx": 8761, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [ + { + "frame_idx": 8761, + "class": "box", + "bbox": [ + 61.0, + 5.0, + 705.0, + 453.0 + ], + "score": 0.872507631778717, + "track_id": "auto-0" + } + ] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009035.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009035.json new file mode 100644 index 0000000000000000000000000000000000000000..c5a44a2577bfa1c49fcc7b6647c52294302941e6 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009035.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 9035, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009170.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009170.json new file mode 100644 index 0000000000000000000000000000000000000000..5a585aaad8dc6d012ae4f2271bdab7a8f651e511 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009170.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 9170, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009465.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009465.json new file mode 100644 index 0000000000000000000000000000000000000000..f7e9fb9703c27006544653601f8a3f41b32664a4 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009465.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 9465, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009520.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009520.json new file mode 100644 index 0000000000000000000000000000000000000000..53c0592ec705c3b3fc4a79d4bf12ebdddef07043 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009520.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 9520, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009873.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009873.json new file mode 100644 index 0000000000000000000000000000000000000000..3a1a87538e63491cfef789a1c65552f158a33e09 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009873.json @@ -0,0 +1,18 @@ +{ + "frame_idx": 9873, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [ + { + "frame_idx": 9873, + "class": "box", + "bbox": [ + 693.0, + 84.0, + 1280.0, + 556.0 + ], + "score": 0.9433499574661255, + "track_id": "auto-0" + } + ] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009936.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009936.json new file mode 100644 index 0000000000000000000000000000000000000000..8d2808ff1f0a9fee4119533cb900e87a7083e9d5 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/009936.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 9936, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010097.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010097.json new file mode 100644 index 0000000000000000000000000000000000000000..1ca64c7bb596defaf1b5f9350b0b3cda240eb8b0 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010097.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 10097, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010214.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010214.json new file mode 100644 index 0000000000000000000000000000000000000000..1c4edd292798438fd38d41b094322ecd1d6c22e9 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010214.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 10214, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010351.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010351.json new file mode 100644 index 0000000000000000000000000000000000000000..7482629ee08d0fc4e1c2107922d30a1905bb008a --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010351.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 10351, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010644.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010644.json new file mode 100644 index 0000000000000000000000000000000000000000..fc36318b1cd7e8e292d717873ebb95f44fa26025 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010644.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 10644, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010701.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010701.json new file mode 100644 index 0000000000000000000000000000000000000000..3fb81a5ebf0b159b3ed5a4db3e4ae3bb389936df --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010701.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 10701, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010994.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010994.json new file mode 100644 index 0000000000000000000000000000000000000000..6e3b82bcea2e8a75abe0ace56b3900f879ab9054 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/010994.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 10994, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/011055.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/011055.json new file mode 100644 index 0000000000000000000000000000000000000000..b989ea003185d42733c9a30f52c01602cd2421ac --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/011055.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 11055, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/011110.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/011110.json new file mode 100644 index 0000000000000000000000000000000000000000..e7c4a905afc4cb6234190674f7995b9c0006b30c --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/011110.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 11110, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/011405.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/011405.json new file mode 100644 index 0000000000000000000000000000000000000000..538c443d21782392e9fadf1ef9986e955ed4bba4 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/011405.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 11405, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +} diff --git a/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/011540.json b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/011540.json new file mode 100644 index 0000000000000000000000000000000000000000..0c3d08a186c36b03dbb0e8b3e46be44c20ee7ca5 --- /dev/null +++ b/video/video_2026-02-18T12-43-09Z/packhouse1_inline3/predictions/011540.json @@ -0,0 +1,5 @@ +{ + "frame_idx": 11540, + "image_key": "video/video_2026-02-18T12-43-09Z/packhouse1_inline3/annotated.mp4", + "predictions": [] +}